diff --git a/funasr/datasets/preprocessor.py b/funasr/datasets/preprocessor.py
index 75bee867c..10fbccba7 100644
--- a/funasr/datasets/preprocessor.py
+++ b/funasr/datasets/preprocessor.py
@@ -538,3 +538,96 @@ class MutliTokenizerCommonPreprocessor(CommonPreprocessor):
         data[text_name] = np.array(text_ints, dtype=np.int64)
         assert check_return_type(data)
         return data
+
+class CodeMixTokenizerCommonPreprocessor(CommonPreprocessor):
+    def __init__(
+        self,
+        train: bool,
+        token_type: str = None,
+        token_list: Union[Path, str, Iterable[str]] = None,
+        bpemodel: Union[Path, str, Iterable[str]] = None,
+        text_cleaner: Collection[str] = None,
+        g2p_type: str = None,
+        unk_symbol: str = "<unk>",
+        space_symbol: str = "<space>",
+        non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
+        delimiter: str = None,
+        rir_scp: str = None,
+        rir_apply_prob: float = 1.0,
+        noise_scp: str = None,
+        noise_apply_prob: float = 1.0,
+        noise_db_range: str = "3_10",
+        speech_volume_normalize: float = None,
+        speech_name: str = "speech",
+        text_name: str = "text",
+        split_text_name: str = "split_text",
+        split_with_space: bool = False,
+        seg_dict_file: str = None,
+    ):
+        super().__init__(
+            train=train,
+            # Force word-level tokenization; this class pre-splits the text itself.
+            token_type="word",
+            token_list=token_list,
+            bpemodel=bpemodel,
+            text_cleaner=text_cleaner,
+            g2p_type=g2p_type,
+            unk_symbol=unk_symbol,
+            space_symbol=space_symbol,
+            non_linguistic_symbols=non_linguistic_symbols,
+            delimiter=delimiter,
+            speech_name=speech_name,
+            text_name=text_name,
+            rir_scp=rir_scp,
+            rir_apply_prob=rir_apply_prob,
+            noise_scp=noise_scp,
+            noise_apply_prob=noise_apply_prob,
+            noise_db_range=noise_db_range,
+            speech_volume_normalize=speech_volume_normalize,
+            split_with_space=split_with_space,
+            seg_dict_file=seg_dict_file,
+        )
+        # The data field name for the split text.
+        self.split_text_name = split_text_name
+
+    @classmethod
+    def split_words(cls, text: str):
+        words = []
+        segs = text.split()
+        for seg in segs:
+            # seg contains no whitespace at this point.
+            current_word = ""
+            for c in seg:
+                if len(c.encode()) == 1:
+                    # ASCII char: extend the current word.
+                    current_word += c
+                else:
+                    # Multi-byte (e.g. Chinese) char: emit it as its own word.
+                    if len(current_word) > 0:
+                        words.append(current_word)
+                        current_word = ""
+                    words.append(c)
+            if len(current_word) > 0:
+                words.append(current_word)
+        return words
+
+    def __call__(
+        self, uid: str, data: Dict[str, Union[list, str, np.ndarray]]
+    ) -> Dict[str, Union[list, np.ndarray]]:
+        assert check_argument_types()
+        # Split words.
+        if isinstance(data[self.text_name], str):
+            split_text = self.split_words(data[self.text_name])
+        else:
+            split_text = data[self.text_name]
+        # Re-join with spaces so the word tokenizer sees explicit boundaries.
+        data[self.text_name] = " ".join(split_text)
+        data = self._speech_process(data)
+        data = self._text_process(data)
+        data[self.split_text_name] = split_text
+        return data
+
+    def pop_split_text_data(self, data: Dict[str, Union[str, np.ndarray]]):
+        result = data[self.split_text_name]
+        del data[self.split_text_name]
+        return result
+
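
For review purposes, a quick sanity check of the splitting rule added in this diff (a sketch run outside the training pipeline; only the classmethod above is used): runs of ASCII characters are kept together as words, while every multi-byte character is emitted as a standalone word.

from funasr.datasets.preprocessor import CodeMixTokenizerCommonPreprocessor

# split_words() is a classmethod, so it can be exercised without
# constructing a full preprocessor (no token_list is needed).
words = CodeMixTokenizerCommonPreprocessor.split_words("打开air conditioner")
print(words)  # ['打', '开', 'air', 'conditioner']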
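
And a hedged sketch of the intended call flow. The token_list path below is a placeholder, and the example assumes a text-only sample; this relies on the inherited CommonPreprocessor helpers (_speech_process / _text_process) skipping fields that are absent from the data dict.

from funasr.datasets.preprocessor import CodeMixTokenizerCommonPreprocessor

# Placeholder vocabulary path: it should hold the word token list that
# the forced token_type="word" tokenizer maps words onto.
preprocessor = CodeMixTokenizerCommonPreprocessor(
    train=False,
    token_list="exp/vocab/tokens.txt",
)

data = {"text": "打开air conditioner"}
data = preprocessor("utt-001", data)
# data["text"] is now an np.int64 array of token ids, and the raw word
# list has been stashed under the extra "split_text" field.
split_text = preprocessor.pop_split_text_data(data)  # ['打', '开', 'air', 'conditioner']
# After popping, data again holds only the fields downstream collators expect.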