# Mirror of https://github.com/modelscope/FunASR
# Synced 2025-09-15 14:48:36 +08:00 (879 lines, 32 KiB, Python)
import re
|
|
from abc import ABC
|
|
from abc import abstractmethod
|
|
from pathlib import Path
|
|
from typing import Collection
|
|
from typing import Dict
|
|
from typing import Iterable
|
|
from typing import List
|
|
from typing import Union
|
|
|
|
import numpy as np
|
|
import scipy.signal
|
|
import soundfile
|
|
import jieba
|
|
|
|
from funasr.text.build_tokenizer import build_tokenizer
|
|
from funasr.text.cleaner import TextCleaner
|
|
from funasr.text.token_id_converter import TokenIDConverter
|
|
|
|
|
|
class AbsPreprocessor(ABC):
    """Interface every data preprocessor in this module implements.

    A concrete preprocessor is a callable that receives one utterance id
    together with its data dictionary and returns the processed dictionary.
    """

    def __init__(self, train: bool):
        """Remember whether this preprocessor is used in training mode."""
        mode_is_train = train
        self.train = mode_is_train

    @abstractmethod
    def __call__(
        self,
        uid: str,
        data: Dict[str, Union[str, np.ndarray]],
    ) -> Dict[str, np.ndarray]:
        """Transform ``data`` for utterance ``uid``; subclasses must override."""
        raise NotImplementedError
|
|
|
|
|
|
def forward_segment(text, dic):
    """Segment ``text`` with greedy forward maximum matching.

    Starting at the current position, the longest slice of ``text`` found in
    ``dic`` is taken as the next word; if no slice longer than one character
    is in ``dic``, the single character at that position is emitted.

    Args:
        text: String to segment.
        dic: Container supporting ``in`` (e.g. a set of known words).

    Returns:
        List of segments whose concatenation equals ``text``.
    """
    pieces = []
    pos = 0
    total = len(text)
    while pos < total:
        match = text[pos]
        # Try candidates from the longest down; the first hit is the longest
        # dictionary word starting at ``pos``.
        for stop in range(total, pos, -1):
            candidate = text[pos:stop]
            if candidate in dic:
                if len(candidate) > len(match):
                    match = candidate
                break
        pieces.append(match)
        pos += len(match)
    return pieces
|
|
|
|
# Words consisting only of CJK unified ideographs (U+4E00..U+9FA5) and ASCII
# digits; such out-of-dictionary words are looked up character by character.
_CJK_OR_DIGIT_PATTERN = re.compile(r'^[\u4E00-\u9FA50-9]+$')


def seg_tokenize(txt, seg_dict):
    """Map words to sub-word units using a segmentation dictionary.

    Each word is lower-cased and looked up in ``seg_dict``; its value is a
    whitespace-separated string of units that are appended to the output.
    A word that is missing from the dictionary but consists only of CJK
    ideographs / digits is looked up character by character (``"<unk>"`` for
    characters that are missing too); any other missing word becomes a
    single ``"<unk>"``.

    Args:
        txt: Iterable of words.
        seg_dict: Mapping from word/character to a space-joined unit string.

    Returns:
        List of unit tokens.
    """
    tokens = []
    for word in txt:
        word = word.lower()
        if word in seg_dict:
            tokens.extend(seg_dict[word].split())
        elif _CJK_OR_DIGIT_PATTERN.match(word):
            for char in word:
                if char in seg_dict:
                    tokens.extend(seg_dict[char].split())
                else:
                    tokens.append("<unk>")
        else:
            tokens.append("<unk>")
    return tokens
|
|
|
|
def seg_tokenize_wo_pattern(txt, seg_dict):
    """Map words to sub-word units with a segmentation dictionary (no fallback).

    Unlike ``seg_tokenize`` the words are looked up verbatim (no lower-casing)
    and every word missing from ``seg_dict`` becomes a single ``"<unk>"``.

    Args:
        txt: Iterable of words.
        seg_dict: Mapping from word to a whitespace-separated unit string.

    Returns:
        List of unit tokens.
    """
    tokens = []
    for word in txt:
        if word in seg_dict:
            tokens.extend(seg_dict[word].split())
        else:
            tokens.append("<unk>")
    return tokens
|
|
|
|
|
|
def framing(
    x,
    frame_length: int = 512,
    frame_shift: int = 256,
    centered: bool = True,
    padded: bool = True,
):
    """Slice the last axis of ``x`` into (possibly overlapping) frames.

    Args:
        x: Input array; framing happens along its last axis.
        frame_length: Samples per frame (>= 1, <= ``x.shape[-1]``).
        frame_shift: Hop between consecutive frame starts (> 0).
        centered: Zero-pad ``frame_length // 2`` samples on both ends first.
        padded: Zero-pad the end so the frames tile the (padded) signal.

    Returns:
        Array of shape ``x.shape[:-1] + (num_frames, frame_length)``. Unless
        ``frame_length == frame_shift == 1`` this is a strided view created by
        ``as_strided``, so frames may share memory.

    Raises:
        ValueError: On empty input or invalid frame parameters.
    """
    if x.size == 0:
        raise ValueError("Input array size is zero")
    if frame_length < 1:
        raise ValueError("frame_length must be a positive integer")
    if frame_length > x.shape[-1]:
        raise ValueError("frame_length is greater than input length")
    if frame_shift <= 0:
        raise ValueError("frame_shift must be greater than 0")

    # Leading axes are never padded.
    lead_axes = [(0, 0)] * (x.ndim - 1)

    if centered:
        half = frame_length // 2
        x = np.pad(x, lead_axes + [(half, half)], mode="constant", constant_values=0)

    if padded:
        # Grow the last axis to frame_length + k * frame_shift for integer k.
        excess = x.shape[-1] - frame_length
        tail = (-excess % frame_shift) % frame_length
        x = np.pad(x, lead_axes + [(0, tail)], mode="constant", constant_values=0)

    if frame_length == 1 and frame_shift == 1:
        return x[..., None]

    num_frames = (x.shape[-1] - frame_length) // frame_shift + 1
    step = x.strides[-1]
    return np.lib.stride_tricks.as_strided(
        x,
        shape=x.shape[:-1] + (num_frames, frame_length),
        strides=x.strides[:-1] + (frame_shift * step, step),
    )
|
|
|
|
|
|
def detect_non_silence(
    x: np.ndarray,
    threshold: float = 0.01,
    frame_length: int = 1024,
    frame_shift: int = 512,
    window: str = "boxcar",
) -> np.ndarray:
    """Power based voice activity detection.

    The last axis of ``x`` is framed, each frame is windowed and its mean
    power computed; a frame is "non-silent" when its power exceeds
    ``threshold`` times the mean frame power of the same channel. Frame
    decisions are expanded back to one boolean per sample.

    Args:
        x: (Channel, Time) or (Time,) signal.
        threshold: Relative power threshold.
        frame_length: Samples per analysis frame.
        frame_shift: Hop between frames (also the number of samples each
            frame decision is expanded to).
        window: Window name passed to ``scipy.signal.get_window``.

    Returns:
        Boolean array with the same shape as ``x``. All True when the signal
        is shorter than one frame or every channel has zero power.

    >>> x = np.random.randn(1000)
    >>> detect = detect_non_silence(x)
    >>> assert x.shape == detect.shape
    >>> assert detect.dtype == bool
    """
    # ``np.bool`` was removed in NumPy 1.24; the builtin ``bool`` is portable.
    if x.shape[-1] < frame_length:
        return np.full(x.shape, fill_value=True, dtype=bool)

    # Integer audio would overflow when squared; work in float64.
    if x.dtype.kind in ("i", "u"):
        x = x.astype(np.float64)
    # framed: (C, T, F); a strided view whose frames overlap in memory.
    framed = framing(
        x,
        frame_length=frame_length,
        frame_shift=frame_shift,
        centered=False,
        padded=True,
    )
    # Out-of-place multiply: an in-place ``*=`` on overlapping strided frames
    # would apply the window repeatedly to the shared samples.
    framed_w = framed * scipy.signal.get_window(window, frame_length).astype(
        framed.dtype
    )
    # power: (C, T)
    power = (framed_w ** 2).mean(axis=-1)
    # mean_power: (C, 1)
    mean_power = np.mean(power, axis=-1, keepdims=True)
    if np.all(mean_power == 0):
        return np.full(x.shape, fill_value=True, dtype=bool)
    # detect_frames: (C, T). Multiplying instead of dividing gives the same
    # decision for positive mean power and avoids 0/0 for a silent channel.
    detect_frames = power > threshold * mean_power
    # detects: (C, T, F)
    detects = np.broadcast_to(
        detect_frames[..., None], detect_frames.shape + (frame_shift,)
    )
    # detects: (C, TF)
    detects = detects.reshape(*detect_frames.shape[:-1], -1)
    # End padding can make the frames cover more samples than x has
    # (e.g. frame_shift == frame_length); trim before edge-padding.
    detects = detects[..., : x.shape[-1]]
    return np.pad(
        detects,
        [(0, 0)] * (x.ndim - 1) + [(0, x.shape[-1] - detects.shape[-1])],
        mode="edge",
    )
|
|
|
|
|
|
class CommonPreprocessor(AbsPreprocessor):
    """Speech/text preprocessor shared by most tasks.

    Speech: optional RIR convolution and additive noise (only when ``train``
    and the corresponding scp is given), then optional peak normalization to
    ``speech_volume_normalize``. Text: cleaning, tokenization (or plain space
    splitting, optionally mapped through a segmentation dictionary) and
    conversion to an int64 array of token ids.
    """

    def __init__(
        self,
        train: bool,
        token_type: str = None,
        token_list: Union[Path, str, Iterable[str]] = None,
        bpemodel: Union[Path, str, Iterable[str]] = None,
        text_cleaner: Collection[str] = None,
        g2p_type: str = None,
        unk_symbol: str = "<unk>",
        space_symbol: str = "<space>",
        non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
        delimiter: str = None,
        rir_scp: str = None,
        rir_apply_prob: float = 1.0,
        noise_scp: str = None,
        noise_apply_prob: float = 1.0,
        noise_db_range: str = "3_10",
        speech_volume_normalize: float = None,
        speech_name: str = "speech",
        text_name: str = "text",
        split_with_space: bool = False,
        seg_dict_file: str = None,
    ):
        super().__init__(train)
        self.speech_name = speech_name
        self.text_name = text_name
        self.speech_volume_normalize = speech_volume_normalize
        self.rir_apply_prob = rir_apply_prob
        self.noise_apply_prob = noise_apply_prob
        self.split_with_space = split_with_space
        self.seg_dict = (
            self._load_seg_dict(seg_dict_file) if seg_dict_file is not None else None
        )

        if token_type is not None:
            if token_list is None:
                raise ValueError("token_list is required if token_type is not None")
            self.text_cleaner = TextCleaner(text_cleaner)

            self.tokenizer = build_tokenizer(
                token_type=token_type,
                bpemodel=bpemodel,
                delimiter=delimiter,
                space_symbol=space_symbol,
                non_linguistic_symbols=non_linguistic_symbols,
                g2p_type=g2p_type,
            )
            self.token_id_converter = TokenIDConverter(
                token_list=token_list,
                unk_symbol=unk_symbol,
            )
        else:
            self.text_cleaner = None
            self.tokenizer = None
            self.token_id_converter = None

        if train and rir_scp is not None:
            self.rirs = self._read_scp_paths(rir_scp)
        else:
            self.rirs = None

        if train and noise_scp is not None:
            self.noises = self._read_scp_paths(noise_scp)
            self.noise_db_low, self.noise_db_high = self._parse_db_range(
                noise_db_range
            )
        else:
            self.noises = None

    @staticmethod
    def _load_seg_dict(seg_dict_file: str) -> Dict[str, str]:
        """Read ``<word> <unit> <unit> ...`` lines into {word: "unit unit ..."}."""
        seg_dict = {}
        # Explicit UTF-8: the dictionary typically holds CJK text and must not
        # depend on the process locale.
        with open(seg_dict_file, "r", encoding="utf-8") as f:
            for line in f:
                parts = line.strip().split()
                if not parts:
                    # Skip blank lines instead of failing with IndexError.
                    continue
                seg_dict[parts[0]] = " ".join(parts[1:])
        return seg_dict

    @staticmethod
    def _read_scp_paths(scp: str) -> List[str]:
        """Return the path column of an scp file (``[key] path`` per line)."""
        paths = []
        with open(scp, "r", encoding="utf-8") as f:
            for line in f:
                sps = line.strip().split(None, 1)
                if not sps:
                    continue
                # A single column is taken as the path itself.
                paths.append(sps[0] if len(sps) == 1 else sps[1])
        return paths

    @staticmethod
    def _parse_db_range(noise_db_range: str):
        """Parse ``"low_high"`` (or a single ``"value"``) into (low, high) floats."""
        sps = noise_db_range.split("_")
        if len(sps) == 1:
            low = high = float(sps[0])
        elif len(sps) == 2:
            low, high = float(sps[0]), float(sps[1])
        else:
            raise ValueError(
                f"Format error: '{noise_db_range}' e.g. -3_4 -> [-3db,4db]"
            )
        return low, high

    def _convolve_rir(self, speech: np.ndarray, power) -> np.ndarray:
        """Convolve (Nmic, Time) ``speech`` with a random RIR, keeping ``power``."""
        rir_path = np.random.choice(self.rirs)
        if rir_path is None:
            return speech
        rir, _ = soundfile.read(rir_path, dtype=np.float64, always_2d=True)
        # rir: (Nmic, Time)
        rir = rir.T
        # Truncating the "full" convolution keeps the signal length unchanged.
        speech = scipy.signal.convolve(speech, rir, mode="full")[:, : speech.shape[1]]
        # Restore the original mean power of the non-silent region.
        power2 = (speech[detect_non_silence(speech)] ** 2).mean()
        return np.sqrt(power / max(power2, 1e-10)) * speech

    def _add_noise(self, speech: np.ndarray, power, nsamples: int) -> np.ndarray:
        """Add a random noise clip to (Nmic, Time) ``speech`` at a random SNR."""
        noise_path = np.random.choice(self.noises)
        if noise_path is None:
            return speech
        noise_db = np.random.uniform(self.noise_db_low, self.noise_db_high)
        with soundfile.SoundFile(noise_path) as f:
            if f.frames == nsamples:
                noise = f.read(dtype=np.float64, always_2d=True)
            elif f.frames < nsamples:
                offset = np.random.randint(0, nsamples - f.frames)
                # noise: (Time, Nmic)
                noise = f.read(dtype=np.float64, always_2d=True)
                # Repeat (wrap) the short noise to cover the whole utterance.
                noise = np.pad(
                    noise,
                    [(offset, nsamples - f.frames - offset), (0, 0)],
                    mode="wrap",
                )
            else:
                offset = np.random.randint(0, f.frames - nsamples)
                f.seek(offset)
                # noise: (Time, Nmic)
                noise = f.read(nsamples, dtype=np.float64, always_2d=True)
                if len(noise) != nsamples:
                    raise RuntimeError(f"Something wrong: {noise_path}")
        # noise: (Nmic, Time)
        noise = noise.T
        noise_power = (noise ** 2).mean()
        scale = (
            10 ** (-noise_db / 20)
            * np.sqrt(power)
            / np.sqrt(max(noise_power, 1e-10))
        )
        return speech + scale * noise

    def _speech_process(
        self, data: Dict[str, Union[str, np.ndarray]]
    ) -> Dict[str, Union[str, np.ndarray]]:
        """Apply training-time augmentation and optional volume normalization."""
        if self.speech_name in data:
            if self.train and (self.rirs is not None or self.noises is not None):
                speech = data[self.speech_name]
                nsamples = len(speech)

                # speech: (Nmic, Time)
                speech = speech[None, :] if speech.ndim == 1 else speech.T
                # Reference power measured on the non-silent region.
                power = (speech[detect_non_silence(speech)] ** 2).mean()

                # 1. Convolve RIR
                if self.rirs is not None and self.rir_apply_prob >= np.random.random():
                    speech = self._convolve_rir(speech, power)

                # 2. Add Noise
                if (
                    self.noises is not None
                    and self.noise_apply_prob >= np.random.random()
                ):
                    speech = self._add_noise(speech, power, nsamples)

                speech = speech.T
                ma = np.max(np.abs(speech))
                if ma > 1.0:
                    # Out-of-place: when no augmentation fired, ``speech`` is
                    # still a view of the caller's (possibly integer) array.
                    speech = speech / ma
                data[self.speech_name] = speech

            if self.speech_volume_normalize is not None:
                speech = data[self.speech_name]
                ma = np.max(np.abs(speech))
                # An all-zero signal has no peak to normalize; leave it as is.
                if ma > 0:
                    data[self.speech_name] = speech * self.speech_volume_normalize / ma
        return data

    def _text_process(
        self, data: Dict[str, Union[str, np.ndarray]]
    ) -> Dict[str, np.ndarray]:
        """Replace the text field with an int64 array of token ids."""
        if self.text_name in data and self.tokenizer is not None:
            text = data[self.text_name]
            text = self.text_cleaner(text)
            if self.split_with_space:
                tokens = text.strip().split(" ")
                if self.seg_dict is not None:
                    tokens = seg_tokenize(tokens, self.seg_dict)
            else:
                tokens = self.tokenizer.text2tokens(text)
            text_ints = self.token_id_converter.tokens2ids(tokens)
            data[self.text_name] = np.array(text_ints, dtype=np.int64)
        return data

    def __call__(
        self, uid: str, data: Dict[str, Union[str, np.ndarray]]
    ) -> Dict[str, np.ndarray]:
        """Process speech then text for utterance ``uid``."""
        data = self._speech_process(data)
        data = self._text_process(data)
        return data
|
|
|
|
## FIXME
|
|
class LMPreprocessor(CommonPreprocessor):
    """``CommonPreprocessor`` variant for language-model text.

    The only difference is the segmentation-dictionary lookup used when
    ``split_with_space`` is set: words are mapped verbatim through
    ``seg_tokenize_wo_pattern`` (no lower-casing, no per-character fallback).
    """

    def __init__(
        self,
        train: bool,
        token_type: str = None,
        token_list: Union[Path, str, Iterable[str]] = None,
        bpemodel: Union[Path, str, Iterable[str]] = None,
        text_cleaner: Collection[str] = None,
        g2p_type: str = None,
        unk_symbol: str = "<unk>",
        space_symbol: str = "<space>",
        non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
        delimiter: str = None,
        rir_scp: str = None,
        rir_apply_prob: float = 1.0,
        noise_scp: str = None,
        noise_apply_prob: float = 1.0,
        noise_db_range: str = "3_10",
        speech_volume_normalize: float = None,
        speech_name: str = "speech",
        text_name: str = "text",
        split_with_space: bool = False,
        seg_dict_file: str = None,
    ):
        # Keyword arguments so the call stays correct if the parent's
        # parameter order ever changes.
        super().__init__(
            train=train,
            token_type=token_type,
            token_list=token_list,
            bpemodel=bpemodel,
            text_cleaner=text_cleaner,
            g2p_type=g2p_type,
            unk_symbol=unk_symbol,
            space_symbol=space_symbol,
            non_linguistic_symbols=non_linguistic_symbols,
            delimiter=delimiter,
            rir_scp=rir_scp,
            rir_apply_prob=rir_apply_prob,
            noise_scp=noise_scp,
            noise_apply_prob=noise_apply_prob,
            noise_db_range=noise_db_range,
            speech_volume_normalize=speech_volume_normalize,
            speech_name=speech_name,
            text_name=text_name,
            split_with_space=split_with_space,
            seg_dict_file=seg_dict_file,
        )

    def _text_process(
        self, data: Dict[str, Union[str, np.ndarray]]
    ) -> Dict[str, np.ndarray]:
        """Replace the text field with token ids (verbatim seg-dict lookup)."""
        if self.text_name not in data or self.tokenizer is None:
            return data
        cleaned = self.text_cleaner(data[self.text_name])
        if not self.split_with_space:
            tokens = self.tokenizer.text2tokens(cleaned)
        else:
            tokens = cleaned.strip().split(" ")
            if self.seg_dict is not None:
                tokens = seg_tokenize_wo_pattern(tokens, self.seg_dict)
        ids = self.token_id_converter.tokens2ids(tokens)
        data[self.text_name] = np.array(ids, dtype=np.int64)
        return data
|
|
|
|
|
|
class CommonPreprocessor_multi(AbsPreprocessor):
    """Tokenize several text fields with one shared tokenizer.

    Every field named in ``text_name`` that is present in the data dict is
    cleaned, tokenized and replaced by an int64 array of token ids. Speech is
    passed through unchanged.
    """

    def __init__(
        self,
        train: bool,
        token_type: str = None,
        token_list: Union[Path, str, Iterable[str]] = None,
        bpemodel: Union[Path, str, Iterable[str]] = None,
        text_cleaner: Collection[str] = None,
        g2p_type: str = None,
        unk_symbol: str = "<unk>",
        space_symbol: str = "<space>",
        non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
        delimiter: str = None,
        speech_name: str = "speech",
        text_name: List[str] = None,
    ):
        super().__init__(train)
        self.speech_name = speech_name
        # ``None`` replaces the former shared mutable default ``["text"]``;
        # a bare string is treated as a single field name rather than being
        # iterated character by character.
        if text_name is None:
            text_name = ["text"]
        elif isinstance(text_name, str):
            text_name = [text_name]
        self.text_name = list(text_name)

        if token_type is not None:
            if token_list is None:
                raise ValueError("token_list is required if token_type is not None")
            self.text_cleaner = TextCleaner(text_cleaner)

            self.tokenizer = build_tokenizer(
                token_type=token_type,
                bpemodel=bpemodel,
                delimiter=delimiter,
                space_symbol=space_symbol,
                non_linguistic_symbols=non_linguistic_symbols,
                g2p_type=g2p_type,
            )
            self.token_id_converter = TokenIDConverter(
                token_list=token_list,
                unk_symbol=unk_symbol,
            )
        else:
            self.text_cleaner = None
            self.tokenizer = None
            self.token_id_converter = None

    def _text_process(
        self, data: Dict[str, Union[str, np.ndarray]]
    ) -> Dict[str, np.ndarray]:
        """Convert each configured text field present in ``data`` to token ids."""
        if self.tokenizer is None:
            return data
        for text_n in self.text_name:
            if text_n in data:
                text = self.text_cleaner(data[text_n])
                tokens = self.tokenizer.text2tokens(text)
                text_ints = self.token_id_converter.tokens2ids(tokens)
                data[text_n] = np.array(text_ints, dtype=np.int64)
        return data

    def __call__(
        self, uid: str, data: Dict[str, Union[str, np.ndarray]]
    ) -> Dict[str, np.ndarray]:
        """Process the text fields of utterance ``uid``.

        Speech is not processed here yet; candidates are STFT, Fbank, CMVN
        and data augmentation.
        """
        return self._text_process(data)
|
|
|
|
|
|
class MutliTokenizerCommonPreprocessor(CommonPreprocessor):
    """``CommonPreprocessor`` with one tokenizer per text field.

    ``token_type[i]``, ``token_list[i]`` and ``bpemodel[i]`` configure the
    tokenizer for field ``text_name[i]``; a ``None`` token type leaves that
    field untouched. Speech processing is inherited unchanged.
    """

    def __init__(
        self,
        train: bool,
        token_type: List[str] = None,
        token_list: List[Union[Path, str, Iterable[str]]] = None,
        bpemodel: List[Union[Path, str, Iterable[str]]] = None,
        text_cleaner: Collection[str] = None,
        g2p_type: str = None,
        unk_symbol: str = "<unk>",
        space_symbol: str = "<space>",
        non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
        delimiter: str = None,
        rir_scp: str = None,
        rir_apply_prob: float = 1.0,
        noise_scp: str = None,
        noise_apply_prob: float = 1.0,
        noise_db_range: str = "3_10",
        speech_volume_normalize: float = None,
        speech_name: str = "speech",
        text_name: List[str] = None,
    ):
        # ``None`` replaces the former shared mutable list defaults
        # (``text_name`` in particular is stored on the instance).
        token_type = [None] if token_type is None else token_type
        token_list = [None] if token_list is None else token_list
        bpemodel = [None] if bpemodel is None else bpemodel
        text_name = ["text"] if text_name is None else text_name

        # TODO(jiatong): sync with Kamo and Jing on interface for preprocessor
        super().__init__(
            train=train,
            token_type=token_type[0],
            token_list=token_list[0],
            bpemodel=bpemodel[0],
            text_cleaner=text_cleaner,
            g2p_type=g2p_type,
            unk_symbol=unk_symbol,
            space_symbol=space_symbol,
            non_linguistic_symbols=non_linguistic_symbols,
            delimiter=delimiter,
            speech_name=speech_name,
            text_name=text_name[0],
            rir_scp=rir_scp,
            rir_apply_prob=rir_apply_prob,
            noise_scp=noise_scp,
            noise_apply_prob=noise_apply_prob,
            noise_db_range=noise_db_range,
            speech_volume_normalize=speech_volume_normalize,
        )

        assert (
            len(token_type) == len(token_list) == len(bpemodel) == len(text_name)
        ), "token_type, token_list, bpemodel, or processing text_name mismatched"
        self.num_tokenizer = len(token_type)
        self.tokenizer = []
        self.token_id_converter = []

        for i in range(self.num_tokenizer):
            if token_type[i] is not None:
                if token_list[i] is None:
                    raise ValueError("token_list is required if token_type is not None")

                self.tokenizer.append(
                    build_tokenizer(
                        token_type=token_type[i],
                        bpemodel=bpemodel[i],
                        delimiter=delimiter,
                        space_symbol=space_symbol,
                        non_linguistic_symbols=non_linguistic_symbols,
                        g2p_type=g2p_type,
                    )
                )
                self.token_id_converter.append(
                    TokenIDConverter(
                        token_list=token_list[i],
                        unk_symbol=unk_symbol,
                    )
                )
            else:
                self.tokenizer.append(None)
                self.token_id_converter.append(None)

        self.text_cleaner = TextCleaner(text_cleaner)
        self.text_name = text_name  # override the text_name from CommonPreprocessor

    def _text_process(
        self, data: Dict[str, Union[str, np.ndarray]]
    ) -> Dict[str, np.ndarray]:
        """Tokenize each configured field with its own tokenizer."""
        for text_name, tokenizer, converter in zip(
            self.text_name, self.tokenizer, self.token_id_converter
        ):
            if text_name in data and tokenizer is not None:
                text = self.text_cleaner(data[text_name])
                tokens = tokenizer.text2tokens(text)
                text_ints = converter.tokens2ids(tokens)
                data[text_name] = np.array(text_ints, dtype=np.int64)
        return data
|
|
|
|
class CodeMixTokenizerCommonPreprocessor(CommonPreprocessor):
    """Preprocessor for code-mixed (e.g. Chinese/English) transcripts.

    Forces ``token_type="word"``. A string text field is first split into
    words — non-ASCII characters individually and ASCII runs as words, or
    via jieba when ``seg_jieba`` is set — then re-joined with spaces and
    tokenized. The split word list is stored under ``split_text_name``.
    """

    # A purely English token: ASCII letters and apostrophes only.
    _ENGLISH_PATTERN = re.compile(r"^[a-zA-Z']+$")

    def __init__(
        self,
        train: bool,
        token_type: str = None,
        token_list: Union[Path, str, Iterable[str]] = None,
        bpemodel: Union[Path, str, Iterable[str]] = None,
        text_cleaner: Collection[str] = None,
        g2p_type: str = None,
        unk_symbol: str = "<unk>",
        space_symbol: str = "<space>",
        non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
        delimiter: str = None,
        rir_scp: str = None,
        rir_apply_prob: float = 1.0,
        noise_scp: str = None,
        noise_apply_prob: float = 1.0,
        noise_db_range: str = "3_10",
        speech_volume_normalize: float = None,
        speech_name: str = "speech",
        text_name: str = "text",
        split_text_name: str = "split_text",
        split_with_space: bool = False,
        seg_jieba: bool = False,
        seg_dict_file: str = None,
    ):
        super().__init__(
            train=train,
            # Force to use word.
            token_type="word",
            token_list=token_list,
            bpemodel=bpemodel,
            text_cleaner=text_cleaner,
            g2p_type=g2p_type,
            unk_symbol=unk_symbol,
            space_symbol=space_symbol,
            non_linguistic_symbols=non_linguistic_symbols,
            delimiter=delimiter,
            speech_name=speech_name,
            text_name=text_name,
            rir_scp=rir_scp,
            rir_apply_prob=rir_apply_prob,
            noise_scp=noise_scp,
            noise_apply_prob=noise_apply_prob,
            noise_db_range=noise_db_range,
            speech_volume_normalize=speech_volume_normalize,
            split_with_space=split_with_space,
            seg_dict_file=seg_dict_file,
        )
        # The data field name for split text.
        self.split_text_name = split_text_name
        self.seg_jieba = seg_jieba
        # The user dictionary is optional; jieba.load_userdict(None) would fail.
        if self.seg_jieba and seg_dict_file is not None:
            jieba.load_userdict(seg_dict_file)

    @classmethod
    def split_words(cls, text: str):
        """Split on whitespace, then emit non-ASCII chars singly and ASCII runs whole."""
        words = []
        for seg in text.split():
            # There is no space in seg.
            current_word = ""
            for c in seg:
                if len(c.encode()) == 1:
                    # This is an ASCII char.
                    current_word += c
                else:
                    # A multi-byte (e.g. Chinese) char ends the ASCII run.
                    if current_word:
                        words.append(current_word)
                        current_word = ""
                    words.append(c)
            if current_word:
                words.append(current_word)
        return words

    @classmethod
    def isEnglish(cls, text: str):
        """Return True if ``text`` consists only of ASCII letters and apostrophes."""
        return bool(cls._ENGLISH_PATTERN.search(text))

    @classmethod
    def join_chinese_and_english(cls, input_list):
        """Concatenate tokens, putting a space before each English token."""
        return "".join(
            " " + token if cls.isEnglish(token) else token for token in input_list
        ).strip()

    @classmethod
    def split_words_jieba(cls, text: str):
        """Segment ``text``: English runs kept as-is, other runs cut by jieba."""
        result_list = []
        run = []
        run_is_english = None

        def flush():
            if run_is_english:
                result_list.extend(run)
            else:
                result_list.extend(
                    jieba.cut(cls.join_chinese_and_english(run), HMM=False)
                )

        for token in text.split():
            is_english = cls.isEnglish(token)
            # A language switch closes the current run of same-language tokens.
            if run and is_english != run_is_english:
                flush()
                run = []
            run.append(token)
            run_is_english = is_english
        if run:
            flush()
        return result_list

    def __call__(
        self, uid: str, data: Dict[str, Union[list, str, np.ndarray]]
    ) -> Dict[str, Union[list, np.ndarray]]:
        """Split, tokenize and store the split words for utterance ``uid``."""
        if self.text_name not in data:
            # No transcript (e.g. speech-only input): only process speech.
            return self._speech_process(data)

        # Split words.
        if isinstance(data[self.text_name], str):
            if self.seg_jieba:
                split_text = self.split_words_jieba(data[self.text_name])
            else:
                split_text = self.split_words(data[self.text_name])
        else:
            split_text = data[self.text_name]
        data[self.text_name] = " ".join(split_text)
        data = self._speech_process(data)
        data = self._text_process(data)
        data[self.split_text_name] = split_text
        return data

    def pop_split_text_data(self, data: Dict[str, Union[str, np.ndarray]]):
        """Remove and return the split-word list stored by ``__call__``."""
        return data.pop(self.split_text_name)
|
|
|
|
class PuncTrainTokenizerCommonPreprocessor(CommonPreprocessor):
    """Multi-tokenizer preprocessor for punctuation-model training.

    One tokenizer per entry of ``text_name`` (a ``None`` token type leaves
    that field untouched). If the last token of a field has the form
    ``vad:<int>``, it is removed from the tokens and stored as a length-1
    int64 array under ``vad_name`` (``-1`` when nothing follows ``vad:``).
    """

    def __init__(
        self,
        train: bool,
        token_type: List[str] = None,
        token_list: List[Union[Path, str, Iterable[str]]] = None,
        bpemodel: List[Union[Path, str, Iterable[str]]] = None,
        text_cleaner: Collection[str] = None,
        g2p_type: str = None,
        unk_symbol: str = "<unk>",
        space_symbol: str = "<space>",
        non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
        delimiter: str = None,
        rir_scp: str = None,
        rir_apply_prob: float = 1.0,
        noise_scp: str = None,
        noise_apply_prob: float = 1.0,
        noise_db_range: str = "3_10",
        speech_volume_normalize: float = None,
        speech_name: str = "speech",
        text_name: List[str] = None,
        vad_name: str = "vad_indexes",
    ):
        # ``None`` replaces the former shared mutable list defaults
        # (``text_name`` in particular is stored on the instance).
        token_type = [None] if token_type is None else token_type
        token_list = [None] if token_list is None else token_list
        bpemodel = [None] if bpemodel is None else bpemodel
        text_name = ["text"] if text_name is None else text_name

        # TODO(jiatong): sync with Kamo and Jing on interface for preprocessor
        super().__init__(
            train=train,
            token_type=token_type[0],
            token_list=token_list[0],
            bpemodel=bpemodel[0],
            text_cleaner=text_cleaner,
            g2p_type=g2p_type,
            unk_symbol=unk_symbol,
            space_symbol=space_symbol,
            non_linguistic_symbols=non_linguistic_symbols,
            delimiter=delimiter,
            speech_name=speech_name,
            text_name=text_name[0],
            rir_scp=rir_scp,
            rir_apply_prob=rir_apply_prob,
            noise_scp=noise_scp,
            noise_apply_prob=noise_apply_prob,
            noise_db_range=noise_db_range,
            speech_volume_normalize=speech_volume_normalize,
        )

        assert (
            len(token_type) == len(token_list) == len(bpemodel) == len(text_name)
        ), "token_type, token_list, bpemodel, or processing text_name mismatched"
        self.num_tokenizer = len(token_type)
        self.tokenizer = []
        self.token_id_converter = []

        for i in range(self.num_tokenizer):
            if token_type[i] is not None:
                if token_list[i] is None:
                    raise ValueError("token_list is required if token_type is not None")

                self.tokenizer.append(
                    build_tokenizer(
                        token_type=token_type[i],
                        bpemodel=bpemodel[i],
                        delimiter=delimiter,
                        space_symbol=space_symbol,
                        non_linguistic_symbols=non_linguistic_symbols,
                        g2p_type=g2p_type,
                    )
                )
                self.token_id_converter.append(
                    TokenIDConverter(
                        token_list=token_list[i],
                        unk_symbol=unk_symbol,
                    )
                )
            else:
                self.tokenizer.append(None)
                self.token_id_converter.append(None)

        self.text_cleaner = TextCleaner(text_cleaner)
        self.text_name = text_name  # override the text_name from CommonPreprocessor
        self.vad_name = vad_name

    def _text_process(
        self, data: Dict[str, Union[str, np.ndarray]]
    ) -> Dict[str, np.ndarray]:
        """Tokenize each field, extracting a trailing ``vad:<int>`` marker."""
        vad_prefix = "vad:"
        for text_name, tokenizer, converter in zip(
            self.text_name, self.tokenizer, self.token_id_converter
        ):
            if text_name in data and tokenizer is not None:
                text = self.text_cleaner(data[text_name])
                tokens = tokenizer.text2tokens(text)
                # An empty line yields no tokens, so check before peeking at
                # the last one; the marker is parsed as a prefix.
                if tokens and tokens[-1].startswith(vad_prefix):
                    vad = tokens[-1][len(vad_prefix):]
                    tokens = tokens[:-1]
                    vad = int(vad) if vad else -1
                    data[self.vad_name] = np.array([vad], dtype=np.int64)
                text_ints = converter.tokens2ids(tokens)
                data[text_name] = np.array(text_ints, dtype=np.int64)
        return data
|
|
|
|
def split_to_mini_sentence(words: list, word_limit: int = 20):
    """Chunk ``words`` into consecutive pieces of at most ``word_limit`` items.

    A sequence that already fits is returned unchanged, wrapped in a list
    (the same object, not a copy). Otherwise every chunk holds
    ``word_limit`` items except possibly the last, which holds the rest.
    """
    assert word_limit > 1
    if len(words) <= word_limit:
        return [words]
    return [
        words[start:start + word_limit]
        for start in range(0, len(words), word_limit)
    ]
|