This commit is contained in:
九耳 2023-02-05 10:48:12 +08:00
parent 3cad3b836e
commit 86d65112ab

View File

@ -538,3 +538,96 @@ class MutliTokenizerCommonPreprocessor(CommonPreprocessor):
data[text_name] = np.array(text_ints, dtype=np.int64)
assert check_return_type(data)
return data
class CodeMixTokenizerCommonPreprocessor(CommonPreprocessor):
    """Preprocessor for code-mixed (Chinese + ASCII/English) text.

    The raw transcript is pre-split so that each multi-byte (assumed
    Chinese) character becomes its own word while runs of ASCII
    characters stay together as one word; tokenization is then
    delegated to ``CommonPreprocessor`` with ``token_type`` forced to
    ``"word"``.
    """

    def __init__(
        self,
        train: bool,
        token_type: str = None,
        token_list: Union[Path, str, Iterable[str]] = None,
        bpemodel: Union[Path, str, Iterable[str]] = None,
        text_cleaner: Collection[str] = None,
        g2p_type: str = None,
        unk_symbol: str = "<unk>",
        space_symbol: str = "<space>",
        non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
        delimiter: str = None,
        rir_scp: str = None,
        rir_apply_prob: float = 1.0,
        noise_scp: str = None,
        noise_apply_prob: float = 1.0,
        noise_db_range: str = "3_10",
        speech_volume_normalize: float = None,
        speech_name: str = "speech",
        text_name: str = "text",
        split_text_name: str = "split_text",
        split_with_space: bool = False,
        seg_dict_file: str = None,
    ):
        """Build the preprocessor.

        Most parameters are forwarded unchanged to ``CommonPreprocessor``;
        see its docstring for their meaning.

        Args:
            token_type: Accepted only for interface compatibility and
                ignored — ``"word"`` is always used because the text is
                pre-split by :meth:`split_words`.
            split_text_name: Key under which :meth:`__call__` stores the
                list of split words in the data dict.
        """
        super().__init__(
            train=train,
            # Force to use word-level tokenization: the text has already
            # been split into words by split_words(), regardless of the
            # token_type the caller asked for.
            token_type="word",
            token_list=token_list,
            bpemodel=bpemodel,
            text_cleaner=text_cleaner,
            g2p_type=g2p_type,
            unk_symbol=unk_symbol,
            space_symbol=space_symbol,
            non_linguistic_symbols=non_linguistic_symbols,
            delimiter=delimiter,
            speech_name=speech_name,
            text_name=text_name,
            rir_scp=rir_scp,
            rir_apply_prob=rir_apply_prob,
            noise_scp=noise_scp,
            noise_apply_prob=noise_apply_prob,
            noise_db_range=noise_db_range,
            speech_volume_normalize=speech_volume_normalize,
            split_with_space=split_with_space,
            seg_dict_file=seg_dict_file,
        )
        # The data field name for split text.
        self.split_text_name = split_text_name

    @classmethod
    def split_words(cls, text: str):
        """Split ``text`` into code-mix words.

        Whitespace first separates segments; within a segment, every
        single-byte (ASCII) character is accumulated into the current
        word, while every multi-byte character (presumably Chinese) is
        emitted as a standalone word.

        Returns:
            list: The split words, in original order.
        """
        words = []
        for seg in text.split():
            # There is no whitespace inside `seg`.
            current_word = ""
            for ch in seg:
                if len(ch.encode()) == 1:
                    # ASCII char: extend the current word.
                    current_word += ch
                else:
                    # Multi-byte char: flush any pending ASCII word,
                    # then emit this char as a word of its own.
                    if current_word:
                        words.append(current_word)
                        current_word = ""
                    words.append(ch)
            # Flush a trailing ASCII word at the end of the segment.
            if current_word:
                words.append(current_word)
        return words

    def __call__(
        self, uid: str, data: Dict[str, Union[list, str, np.ndarray]]
    ) -> Dict[str, Union[list, np.ndarray]]:
        """Preprocess one sample.

        Splits the sample's text, rejoins it with spaces for the parent
        word tokenizer, runs the inherited speech/text pipelines, and
        attaches the split word list under ``self.split_text_name``.
        """
        assert check_argument_types()
        # Split words. A str still needs splitting; anything else is
        # assumed to already be a pre-split sequence of words.
        if isinstance(data[self.text_name], str):
            split_text = self.split_words(data[self.text_name])
        else:
            split_text = data[self.text_name]
        data[self.text_name] = " ".join(split_text)
        data = self._speech_process(data)
        data = self._text_process(data)
        data[self.split_text_name] = split_text
        return data

    def pop_split_text_data(self, data: Dict[str, Union[str, np.ndarray]]):
        """Remove the split-text entry from ``data`` and return it."""
        # dict.pop combines the lookup and the deletion in one step
        # (previously an index access followed by `del`).
        return data.pop(self.split_text_name)