mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
fix
This commit is contained in:
parent
3cad3b836e
commit
86d65112ab
@ -538,3 +538,96 @@ class MutliTokenizerCommonPreprocessor(CommonPreprocessor):
|
||||
data[text_name] = np.array(text_ints, dtype=np.int64)
|
||||
assert check_return_type(data)
|
||||
return data
|
||||
|
||||
class CodeMixTokenizerCommonPreprocessor(CommonPreprocessor):
|
||||
def __init__(
|
||||
self,
|
||||
train: bool,
|
||||
token_type: str = None,
|
||||
token_list: Union[Path, str, Iterable[str]] = None,
|
||||
bpemodel: Union[Path, str, Iterable[str]] = None,
|
||||
text_cleaner: Collection[str] = None,
|
||||
g2p_type: str = None,
|
||||
unk_symbol: str = "<unk>",
|
||||
space_symbol: str = "<space>",
|
||||
non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
|
||||
delimiter: str = None,
|
||||
rir_scp: str = None,
|
||||
rir_apply_prob: float = 1.0,
|
||||
noise_scp: str = None,
|
||||
noise_apply_prob: float = 1.0,
|
||||
noise_db_range: str = "3_10",
|
||||
speech_volume_normalize: float = None,
|
||||
speech_name: str = "speech",
|
||||
text_name: str = "text",
|
||||
split_text_name: str = "split_text",
|
||||
split_with_space: bool = False,
|
||||
seg_dict_file: str = None,
|
||||
):
|
||||
super().__init__(
|
||||
train=train,
|
||||
# Force to use word.
|
||||
token_type="word",
|
||||
token_list=token_list,
|
||||
bpemodel=bpemodel,
|
||||
text_cleaner=text_cleaner,
|
||||
g2p_type=g2p_type,
|
||||
unk_symbol=unk_symbol,
|
||||
space_symbol=space_symbol,
|
||||
non_linguistic_symbols=non_linguistic_symbols,
|
||||
delimiter=delimiter,
|
||||
speech_name=speech_name,
|
||||
text_name=text_name,
|
||||
rir_scp=rir_scp,
|
||||
rir_apply_prob=rir_apply_prob,
|
||||
noise_scp=noise_scp,
|
||||
noise_apply_prob=noise_apply_prob,
|
||||
noise_db_range=noise_db_range,
|
||||
speech_volume_normalize=speech_volume_normalize,
|
||||
split_with_space=split_with_space,
|
||||
seg_dict_file=seg_dict_file,
|
||||
)
|
||||
# The data field name for split text.
|
||||
self.split_text_name = split_text_name
|
||||
|
||||
@classmethod
|
||||
def split_words(cls, text: str):
|
||||
words = []
|
||||
segs = text.split()
|
||||
for seg in segs:
|
||||
# There is no space in seg.
|
||||
current_word = ""
|
||||
for c in seg:
|
||||
if len(c.encode()) == 1:
|
||||
# This is an ASCII char.
|
||||
current_word += c
|
||||
else:
|
||||
# This is a Chinese char.
|
||||
if len(current_word) > 0:
|
||||
words.append(current_word)
|
||||
current_word = ""
|
||||
words.append(c)
|
||||
if len(current_word) > 0:
|
||||
words.append(current_word)
|
||||
return words
|
||||
|
||||
def __call__(
|
||||
self, uid: str, data: Dict[str, Union[list, str, np.ndarray]]
|
||||
) -> Dict[str, Union[list, np.ndarray]]:
|
||||
assert check_argument_types()
|
||||
# Split words.
|
||||
if isinstance(data[self.text_name], str):
|
||||
split_text = self.split_words(data[self.text_name])
|
||||
else:
|
||||
split_text = data[self.text_name]
|
||||
data[self.text_name] = " ".join(split_text)
|
||||
data = self._speech_process(data)
|
||||
data = self._text_process(data)
|
||||
data[self.split_text_name] = split_text
|
||||
return data
|
||||
|
||||
def pop_split_text_data(self, data: Dict[str, Union[str, np.ndarray]]):
|
||||
result = data[self.split_text_name]
|
||||
del data[self.split_text_name]
|
||||
return result
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user