From e55178abc21a3a692b7b18cc12922b4004c15f2e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B9=9D=E8=80=B3?= Date: Thu, 30 Mar 2023 14:11:02 +0800 Subject: [PATCH] general punc model runtime --- .../python/onnxruntime/demo_punc_offline.py | 9 + .../onnxruntime/funasr_onnx/__init__.py | 2 + .../onnxruntime/funasr_onnx/punc_bin.py | 133 +++++ .../funasr_onnx/utils/preprocessor.py | 470 ++++++++++++++++++ .../onnxruntime/funasr_onnx/utils/utils.py | 13 + 5 files changed, 627 insertions(+) create mode 100644 funasr/runtime/python/onnxruntime/demo_punc_offline.py create mode 100644 funasr/runtime/python/onnxruntime/funasr_onnx/utils/preprocessor.py diff --git a/funasr/runtime/python/onnxruntime/demo_punc_offline.py b/funasr/runtime/python/onnxruntime/demo_punc_offline.py new file mode 100644 index 000000000..056f73751 --- /dev/null +++ b/funasr/runtime/python/onnxruntime/demo_punc_offline.py @@ -0,0 +1,9 @@ +from funasr_onnx import TargetDelayTransformer + +model_dir = "/disk1/mengzhe.cmz/workspace/FunASR/funasr/export/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch" +model = TargetDelayTransformer(model_dir) + +text_in = "我们都是木头人不会讲话不会动" + +result = model(text_in) +print(result) diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/__init__.py b/funasr/runtime/python/onnxruntime/funasr_onnx/__init__.py index 475047903..1620a0b25 100644 --- a/funasr/runtime/python/onnxruntime/funasr_onnx/__init__.py +++ b/funasr/runtime/python/onnxruntime/funasr_onnx/__init__.py @@ -1,3 +1,5 @@ # -*- encoding: utf-8 -*- from .paraformer_bin import Paraformer from .vad_bin import Fsmn_vad +from .punc_bin import TargetDelayTransformer +#from .punc_bin import VadRealtimeTransformer diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py b/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py index e69de29bb..64ced69be 100644 --- a/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py +++ b/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py @@ -0,0 +1,133 @@ +# -*- encoding: utf-8 -*- + +import os.path +from pathlib import Path +from typing import List, Union, Tuple +import numpy as np + +from .utils.utils import (ONNXRuntimeError, + OrtInferSession, get_logger, + read_yaml) +from .utils.preprocessor import CodeMixTokenizerCommonPreprocessor +from .utils.utils import split_to_mini_sentence +logging = get_logger() + + +class TargetDelayTransformer(): + def __init__(self, model_dir: Union[str, Path] = None, + batch_size: int = 1, + device_id: Union[str, int] = "-1", + quantize: bool = False, + intra_op_num_threads: int = 4 + ): + + if not Path(model_dir).exists(): + raise FileNotFoundError(f'{model_dir} does not exist.') + + model_file = os.path.join(model_dir, 'model.onnx') + if quantize: + model_file = os.path.join(model_dir, 'model_quant.onnx') + config_file = os.path.join(model_dir, 'punc.yaml') + config = read_yaml(config_file) + + self.ort_infer = OrtInferSession(model_file, device_id, intra_op_num_threads=intra_op_num_threads) + self.batch_size = 1 + self.encoder_conf = config["encoder_conf"] + self.punc_list = config.punc_list + self.period = 0 + for i in range(len(self.punc_list)): + if self.punc_list[i] == ",": + self.punc_list[i] = "," + elif self.punc_list[i] == "?": + self.punc_list[i] = "?" + elif self.punc_list[i] == "。": + self.period = i + self.preprocessor = CodeMixTokenizerCommonPreprocessor( + train=False, + token_type=config.token_type, + token_list=config.token_list, + bpemodel=config.bpemodel, + text_cleaner=config.cleaner, + g2p_type=config.g2p, + text_name="text", + non_linguistic_symbols=config.non_linguistic_symbols, + ) + + def __call__(self, text: Union[list, str], split_size=20): + data = {"text": text} + result = self.preprocessor(data=data, uid="12938712838719") + split_text = self.preprocessor.pop_split_text_data(result) + mini_sentences = split_to_mini_sentence(split_text, split_size) + mini_sentences_id = split_to_mini_sentence(data["text"], split_size) + assert len(mini_sentences) == len(mini_sentences_id) + cache_sent = [] + cache_sent_id = [] + new_mini_sentence = "" + new_mini_sentence_punc = [] + cache_pop_trigger_limit = 200 + for mini_sentence_i in range(len(mini_sentences)): + mini_sentence = mini_sentences[mini_sentence_i] + mini_sentence_id = mini_sentences_id[mini_sentence_i] + mini_sentence = cache_sent + mini_sentence + mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0) + data = { + "text": mini_sentence_id, + "text_lengths": len(mini_sentence_id), + } + try: + outputs = self.infer(data['text'], data['text_lengths']) + y = outputs[0] + _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1) + punctuations = indices + assert punctuations.size()[0] == len(mini_sentence) + except ONNXRuntimeError: + logging.warning("error") + + # Search for the last Period/QuestionMark as cache + if mini_sentence_i < len(mini_sentences) - 1: + sentenceEnd = -1 + last_comma_index = -1 + for i in range(len(punctuations) - 2, 1, -1): + if self.punc_list[punctuations[i]] == "。" or self.punc_list[punctuations[i]] == "?": + sentenceEnd = i + break + if last_comma_index < 0 and self.punc_list[punctuations[i]] == ",": + last_comma_index = i + + if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0: + # The sentence it too long, cut off at a comma. + sentenceEnd = last_comma_index + punctuations[sentenceEnd] = self.period + cache_sent = mini_sentence[sentenceEnd + 1:] + cache_sent_id = mini_sentence_id[sentenceEnd + 1:] + mini_sentence = mini_sentence[0:sentenceEnd + 1] + punctuations = punctuations[0:sentenceEnd + 1] + + punctuations_np = punctuations.cpu().numpy() + new_mini_sentence_punc += [int(x) for x in punctuations_np] + words_with_punc = [] + for i in range(len(mini_sentence)): + if i > 0: + if len(mini_sentence[i][0].encode()) == 1 and len(mini_sentence[i - 1][0].encode()) == 1: + mini_sentence[i] = " " + mini_sentence[i] + words_with_punc.append(mini_sentence[i]) + if self.punc_list[punctuations[i]] != "_": + words_with_punc.append(self.punc_list[punctuations[i]]) + new_mini_sentence += "".join(words_with_punc) + # Add Period for the end of the sentence + new_mini_sentence_out = new_mini_sentence + new_mini_sentence_punc_out = new_mini_sentence_punc + if mini_sentence_i == len(mini_sentences) - 1: + if new_mini_sentence[-1] == "," or new_mini_sentence[-1] == "、": + new_mini_sentence_out = new_mini_sentence[:-1] + "。" + new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period] + elif new_mini_sentence[-1] != "。" and new_mini_sentence[-1] != "?": + new_mini_sentence_out = new_mini_sentence + "。" + new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period] + return new_mini_sentence_out, new_mini_sentence_punc_out + + def infer(self, feats: List) -> Tuple[np.ndarray, np.ndarray]: + + outputs = self.ort_infer(feats) + return outputs + diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/utils/preprocessor.py b/funasr/runtime/python/onnxruntime/funasr_onnx/utils/preprocessor.py new file mode 100644 index 000000000..4c9710371 --- /dev/null +++ b/funasr/runtime/python/onnxruntime/funasr_onnx/utils/preprocessor.py @@ -0,0 +1,470 @@ +import re +from abc import ABC +from abc import abstractmethod +from pathlib import Path +from typing import Collection +from typing import Dict +from typing import Iterable +from typing import List +from typing import Union + +import numpy as np +import scipy.signal +import soundfile +from typeguard import check_argument_types +from typeguard import check_return_type + +from funasr.text.build_tokenizer import build_tokenizer +from funasr.text.cleaner import TextCleaner +from funasr.text.token_id_converter import TokenIDConverter + + +class AbsPreprocessor(ABC): + def __init__(self, train: bool): + self.train = train + + @abstractmethod + def __call__( + self, uid: str, data: Dict[str, Union[str, np.ndarray]] + ) -> Dict[str, np.ndarray]: + raise NotImplementedError + + +def forward_segment(text, dic): + word_list = [] + i = 0 + while i < len(text): + longest_word = text[i] + for j in range(i + 1, len(text) + 1): + word = text[i:j] + if word in dic: + if len(word) > len(longest_word): + longest_word = word + word_list.append(longest_word) + i += len(longest_word) + return word_list + + +def seg_tokenize(txt, seg_dict): + out_txt = "" + for word in txt: + if word in seg_dict: + out_txt += seg_dict[word] + " " + else: + out_txt += "" + " " + return out_txt.strip().split() + +def seg_tokenize_wo_pattern(txt, seg_dict): + out_txt = "" + for word in txt: + if word in seg_dict: + out_txt += seg_dict[word] + " " + else: + out_txt += "" + " " + return out_txt.strip().split() + + +def framing( + x, + frame_length: int = 512, + frame_shift: int = 256, + centered: bool = True, + padded: bool = True, +): + if x.size == 0: + raise ValueError("Input array size is zero") + if frame_length < 1: + raise ValueError("frame_length must be a positive integer") + if frame_length > x.shape[-1]: + raise ValueError("frame_length is greater than input length") + if 0 >= frame_shift: + raise ValueError("frame_shift must be greater than 0") + + if centered: + pad_shape = [(0, 0) for _ in range(x.ndim - 1)] + [ + (frame_length // 2, frame_length // 2) + ] + x = np.pad(x, pad_shape, mode="constant", constant_values=0) + + if padded: + # Pad to integer number of windowed segments + # I.e make x.shape[-1] = frame_length + (nseg-1)*nstep, + # with integer nseg + nadd = (-(x.shape[-1] - frame_length) % frame_shift) % frame_length + pad_shape = [(0, 0) for _ in range(x.ndim - 1)] + [(0, nadd)] + x = np.pad(x, pad_shape, mode="constant", constant_values=0) + + # Created strided array of data segments + if frame_length == 1 and frame_length == frame_shift: + result = x[..., None] + else: + shape = x.shape[:-1] + ( + (x.shape[-1] - frame_length) // frame_shift + 1, + frame_length, + ) + strides = x.strides[:-1] + (frame_shift * x.strides[-1], x.strides[-1]) + result = np.lib.stride_tricks.as_strided(x, shape=shape, strides=strides) + return result + + +def detect_non_silence( + x: np.ndarray, + threshold: float = 0.01, + frame_length: int = 1024, + frame_shift: int = 512, + window: str = "boxcar", +) -> np.ndarray: + """Power based voice activity detection. + + Args: + x: (Channel, Time) + >>> x = np.random.randn(1000) + >>> detect = detect_non_silence(x) + >>> assert x.shape == detect.shape + >>> assert detect.dtype == np.bool + """ + if x.shape[-1] < frame_length: + return np.full(x.shape, fill_value=True, dtype=np.bool) + + if x.dtype.kind == "i": + x = x.astype(np.float64) + # framed_w: (C, T, F) + framed_w = framing( + x, + frame_length=frame_length, + frame_shift=frame_shift, + centered=False, + padded=True, + ) + framed_w *= scipy.signal.get_window(window, frame_length).astype(framed_w.dtype) + # power: (C, T) + power = (framed_w ** 2).mean(axis=-1) + # mean_power: (C, 1) + mean_power = np.mean(power, axis=-1, keepdims=True) + if np.all(mean_power == 0): + return np.full(x.shape, fill_value=True, dtype=np.bool) + # detect_frames: (C, T) + detect_frames = power / mean_power > threshold + # detects: (C, T, F) + detects = np.broadcast_to( + detect_frames[..., None], detect_frames.shape + (frame_shift,) + ) + # detects: (C, TF) + detects = detects.reshape(*detect_frames.shape[:-1], -1) + # detects: (C, TF) + return np.pad( + detects, + [(0, 0)] * (x.ndim - 1) + [(0, x.shape[-1] - detects.shape[-1])], + mode="edge", + ) + + +class CommonPreprocessor(AbsPreprocessor): + def __init__( + self, + train: bool, + token_type: str = None, + token_list: Union[Path, str, Iterable[str]] = None, + bpemodel: Union[Path, str, Iterable[str]] = None, + text_cleaner: Collection[str] = None, + g2p_type: str = None, + unk_symbol: str = "", + space_symbol: str = "", + non_linguistic_symbols: Union[Path, str, Iterable[str]] = None, + delimiter: str = None, + rir_scp: str = None, + rir_apply_prob: float = 1.0, + noise_scp: str = None, + noise_apply_prob: float = 1.0, + noise_db_range: str = "3_10", + speech_volume_normalize: float = None, + speech_name: str = "speech", + text_name: str = "text", + split_with_space: bool = False, + seg_dict_file: str = None, + ): + super().__init__(train) + self.train = train + self.speech_name = speech_name + self.text_name = text_name + self.speech_volume_normalize = speech_volume_normalize + self.rir_apply_prob = rir_apply_prob + self.noise_apply_prob = noise_apply_prob + self.split_with_space = split_with_space + self.seg_dict = None + if seg_dict_file is not None: + self.seg_dict = {} + with open(seg_dict_file) as f: + lines = f.readlines() + for line in lines: + s = line.strip().split() + key = s[0] + value = s[1:] + self.seg_dict[key] = " ".join(value) + + if token_type is not None: + if token_list is None: + raise ValueError("token_list is required if token_type is not None") + self.text_cleaner = TextCleaner(text_cleaner) + + self.tokenizer = build_tokenizer( + token_type=token_type, + bpemodel=bpemodel, + delimiter=delimiter, + space_symbol=space_symbol, + non_linguistic_symbols=non_linguistic_symbols, + g2p_type=g2p_type, + ) + self.token_id_converter = TokenIDConverter( + token_list=token_list, + unk_symbol=unk_symbol, + ) + else: + self.text_cleaner = None + self.tokenizer = None + self.token_id_converter = None + + if train and rir_scp is not None: + self.rirs = [] + with open(rir_scp, "r", encoding="utf-8") as f: + for line in f: + sps = line.strip().split(None, 1) + if len(sps) == 1: + self.rirs.append(sps[0]) + else: + self.rirs.append(sps[1]) + else: + self.rirs = None + + if train and noise_scp is not None: + self.noises = [] + with open(noise_scp, "r", encoding="utf-8") as f: + for line in f: + sps = line.strip().split(None, 1) + if len(sps) == 1: + self.noises.append(sps[0]) + else: + self.noises.append(sps[1]) + sps = noise_db_range.split("_") + if len(sps) == 1: + self.noise_db_low, self.noise_db_high = float(sps[0]) + elif len(sps) == 2: + self.noise_db_low, self.noise_db_high = float(sps[0]), float(sps[1]) + else: + raise ValueError( + "Format error: '{noise_db_range}' e.g. -3_4 -> [-3db,4db]" + ) + else: + self.noises = None + + def _speech_process( + self, data: Dict[str, Union[str, np.ndarray]] + ) -> Dict[str, Union[str, np.ndarray]]: + assert check_argument_types() + if self.speech_name in data: + if self.train and (self.rirs is not None or self.noises is not None): + speech = data[self.speech_name] + nsamples = len(speech) + + # speech: (Nmic, Time) + if speech.ndim == 1: + speech = speech[None, :] + else: + speech = speech.T + # Calc power on non shlence region + power = (speech[detect_non_silence(speech)] ** 2).mean() + + # 1. Convolve RIR + if self.rirs is not None and self.rir_apply_prob >= np.random.random(): + rir_path = np.random.choice(self.rirs) + if rir_path is not None: + rir, _ = soundfile.read( + rir_path, dtype=np.float64, always_2d=True + ) + + # rir: (Nmic, Time) + rir = rir.T + + # speech: (Nmic, Time) + # Note that this operation doesn't change the signal length + speech = scipy.signal.convolve(speech, rir, mode="full")[ + :, : speech.shape[1] + ] + # Reverse mean power to the original power + power2 = (speech[detect_non_silence(speech)] ** 2).mean() + speech = np.sqrt(power / max(power2, 1e-10)) * speech + + # 2. Add Noise + if ( + self.noises is not None + and self.noise_apply_prob >= np.random.random() + ): + noise_path = np.random.choice(self.noises) + if noise_path is not None: + noise_db = np.random.uniform( + self.noise_db_low, self.noise_db_high + ) + with soundfile.SoundFile(noise_path) as f: + if f.frames == nsamples: + noise = f.read(dtype=np.float64, always_2d=True) + elif f.frames < nsamples: + offset = np.random.randint(0, nsamples - f.frames) + # noise: (Time, Nmic) + noise = f.read(dtype=np.float64, always_2d=True) + # Repeat noise + noise = np.pad( + noise, + [(offset, nsamples - f.frames - offset), (0, 0)], + mode="wrap", + ) + else: + offset = np.random.randint(0, f.frames - nsamples) + f.seek(offset) + # noise: (Time, Nmic) + noise = f.read( + nsamples, dtype=np.float64, always_2d=True + ) + if len(noise) != nsamples: + raise RuntimeError(f"Something wrong: {noise_path}") + # noise: (Nmic, Time) + noise = noise.T + + noise_power = (noise ** 2).mean() + scale = ( + 10 ** (-noise_db / 20) + * np.sqrt(power) + / np.sqrt(max(noise_power, 1e-10)) + ) + speech = speech + scale * noise + + speech = speech.T + ma = np.max(np.abs(speech)) + if ma > 1.0: + speech /= ma + data[self.speech_name] = speech + + if self.speech_volume_normalize is not None: + speech = data[self.speech_name] + ma = np.max(np.abs(speech)) + data[self.speech_name] = speech * self.speech_volume_normalize / ma + assert check_return_type(data) + return data + + def _text_process( + self, data: Dict[str, Union[str, np.ndarray]] + ) -> Dict[str, np.ndarray]: + if self.text_name in data and self.tokenizer is not None: + text = data[self.text_name] + text = self.text_cleaner(text) + if self.split_with_space: + tokens = text.strip().split(" ") + if self.seg_dict is not None: + tokens = forward_segment("".join(tokens), self.seg_dict) + tokens = seg_tokenize(tokens, self.seg_dict) + else: + tokens = self.tokenizer.text2tokens(text) + text_ints = self.token_id_converter.tokens2ids(tokens) + data[self.text_name] = np.array(text_ints, dtype=np.int64) + assert check_return_type(data) + return data + + def __call__( + self, uid: str, data: Dict[str, Union[str, np.ndarray]] + ) -> Dict[str, np.ndarray]: + assert check_argument_types() + + data = self._speech_process(data) + data = self._text_process(data) + return data + +class CodeMixTokenizerCommonPreprocessor(CommonPreprocessor): + def __init__( + self, + train: bool, + token_type: str = None, + token_list: Union[Path, str, Iterable[str]] = None, + bpemodel: Union[Path, str, Iterable[str]] = None, + text_cleaner: Collection[str] = None, + g2p_type: str = None, + unk_symbol: str = "", + space_symbol: str = "", + non_linguistic_symbols: Union[Path, str, Iterable[str]] = None, + delimiter: str = None, + rir_scp: str = None, + rir_apply_prob: float = 1.0, + noise_scp: str = None, + noise_apply_prob: float = 1.0, + noise_db_range: str = "3_10", + speech_volume_normalize: float = None, + speech_name: str = "speech", + text_name: str = "text", + split_text_name: str = "split_text", + split_with_space: bool = False, + seg_dict_file: str = None, + ): + super().__init__( + train=train, + # Force to use word. + token_type="word", + token_list=token_list, + bpemodel=bpemodel, + text_cleaner=text_cleaner, + g2p_type=g2p_type, + unk_symbol=unk_symbol, + space_symbol=space_symbol, + non_linguistic_symbols=non_linguistic_symbols, + delimiter=delimiter, + speech_name=speech_name, + text_name=text_name, + rir_scp=rir_scp, + rir_apply_prob=rir_apply_prob, + noise_scp=noise_scp, + noise_apply_prob=noise_apply_prob, + noise_db_range=noise_db_range, + speech_volume_normalize=speech_volume_normalize, + split_with_space=split_with_space, + seg_dict_file=seg_dict_file, + ) + # The data field name for split text. + self.split_text_name = split_text_name + + @classmethod + def split_words(cls, text: str): + words = [] + segs = text.split() + for seg in segs: + # There is no space in seg. + current_word = "" + for c in seg: + if len(c.encode()) == 1: + # This is an ASCII char. + current_word += c + else: + # This is a Chinese char. + if len(current_word) > 0: + words.append(current_word) + current_word = "" + words.append(c) + if len(current_word) > 0: + words.append(current_word) + return words + + def __call__( + self, uid: str, data: Dict[str, Union[list, str, np.ndarray]] + ) -> Dict[str, Union[list, np.ndarray]]: + assert check_argument_types() + # Split words. + if isinstance(data[self.text_name], str): + split_text = self.split_words(data[self.text_name]) + else: + split_text = data[self.text_name] + data[self.text_name] = " ".join(split_text) + data = self._speech_process(data) + data = self._text_process(data) + data[self.split_text_name] = split_text + return data + + def pop_split_text_data(self, data: Dict[str, Union[str, np.ndarray]]): + result = data[self.split_text_name] + del data[self.split_text_name] + return result diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py b/funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py index fccd5a095..c7e607691 100644 --- a/funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py +++ b/funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py @@ -215,6 +215,19 @@ class OrtInferSession(): if not model_path.is_file(): raise FileExistsError(f'{model_path} is not a file.') +def split_to_mini_sentence(words: list, word_limit: int = 20): + assert word_limit > 1 + if len(words) <= word_limit: + return [words] + sentences = [] + length = len(words) + sentence_len = length // word_limit + for i in range(sentence_len): + sentences.append(words[i * word_limit:(i + 1) * word_limit]) + if length % word_limit > 0: + sentences.append(words[sentence_len * word_limit:]) + return sentences + def read_yaml(yaml_path: Union[str, Path]) -> Dict: if not Path(yaml_path).exists():