This commit is contained in:
游雁 2024-06-25 20:43:08 +08:00
parent add1ac00f7
commit 1cbd2015f0
19 changed files with 21 additions and 1177 deletions

View File

@ -171,7 +171,8 @@ class AutoModel:
self.spk_kwargs = spk_kwargs
self.model_path = kwargs.get("model_path")
def build_model(self, **kwargs):
@staticmethod
def build_model(**kwargs):
assert "model" in kwargs
if "model_conf" not in kwargs:
logging.info("download models from model hub: {}".format(kwargs.get("hub", "ms")))

View File

@ -1,9 +0,0 @@
from abc import ABC
from abc import abstractmethod
from typing import Iterator
class AbsIterFactory(ABC):
    """Abstract interface for objects that build a fresh data iterator per epoch."""

    @abstractmethod
    def build_iter(self, epoch: int, shuffle: bool = None) -> Iterator:
        """Return an iterator over the data for the given epoch.

        Args:
            epoch: epoch index, typically used to seed deterministic shuffling.
            shuffle: whether to shuffle; None lets the implementation decide.
        """
        raise NotImplementedError

View File

@ -1,109 +0,0 @@
import logging
from pathlib import Path
from typing import Iterable
from typing import List
from typing import Union
import sentencepiece as spm
from torch.utils.data import DataLoader
from funasr.datasets.large_datasets.dataset import Dataset
from funasr.datasets.large_datasets.abs_iter_factory import AbsIterFactory
from funasr.tokenizer.abs_tokenizer import AbsTokenizer
from funasr.register import tables
def read_symbol_table(symbol_table_file):
    """Build a token -> index mapping.

    Accepts either a path to a vocabulary file (one token per line) or an
    already-loaded list of tokens; indices follow line/list order.
    """
    if isinstance(symbol_table_file, str):
        with open(symbol_table_file, "r", encoding="utf8") as fin:
            return {line.strip(): idx for idx, line in enumerate(fin)}
    assert isinstance(symbol_table_file, list)
    return {token: idx for idx, token in enumerate(symbol_table_file)}
def load_seg_dict(seg_dict_file):
    """Load a word-segmentation dictionary from a text file.

    Each non-empty line is "<word> <piece> <piece> ..."; the stored value is
    the pieces joined by single spaces.

    Args:
        seg_dict_file: path to the dictionary file.

    Returns:
        dict mapping word -> space-joined segmentation.
    """
    seg_dict = {}
    assert isinstance(seg_dict_file, str)
    with open(seg_dict_file, "r", encoding="utf8") as f:
        for line in f:
            s = line.strip().split()
            # BUG FIX: a blank line used to raise IndexError on s[0];
            # skip empty/whitespace-only lines instead.
            if not s:
                continue
            seg_dict[s[0]] = " ".join(s[1:])
    return seg_dict
class SentencepiecesTokenizer(AbsTokenizer):
    """Tokenizer backed by a SentencePiece model, loaded lazily on first use."""

    def __init__(self, model: Union[Path, str]):
        self.model = str(model)
        # Deferred: the processor is created on the first (de)tokenize call,
        # keeping construction cheap.
        self.sp = None

    def __repr__(self):
        return f'{self.__class__.__name__}(model="{self.model}")'

    def _build_sentence_piece_processor(self):
        # Idempotent lazy initialization of the SentencePiece processor.
        if self.sp is not None:
            return
        processor = spm.SentencePieceProcessor()
        processor.load(self.model)
        self.sp = processor

    def text2tokens(self, line: str) -> List[str]:
        """Split a text line into SentencePiece pieces."""
        self._build_sentence_piece_processor()
        return self.sp.EncodeAsPieces(line)

    def tokens2text(self, tokens: Iterable[str]) -> str:
        """Join SentencePiece pieces back into a text string."""
        self._build_sentence_piece_processor()
        return self.sp.DecodePieces(list(tokens))
@tables.register("dataset_classes", "LargeDataset")
class LargeDataLoader(AbsIterFactory):
    """Iter factory wiring funasr's streaming large-dataset pipeline.

    Optional resources (vocabulary, segmentation dict, punctuation table,
    BPE model) are loaded from `args` when configured, the underlying Dataset
    is built once, and build_iter() wraps it in a DataLoader per epoch.
    """

    def __init__(self, args, mode="train"):
        # Each resource is optional; build it only when the config provides it.
        symbol_table = None
        if getattr(args, "token_list", None) is not None:
            symbol_table = read_symbol_table(args.token_list)
        seg_dict = None
        if getattr(args, "seg_dict_file", None) is not None:
            seg_dict = load_seg_dict(args.seg_dict_file)
        punc_dict = None
        if getattr(args, "punc_list", None) is not None:
            punc_dict = read_symbol_table(args.punc_list)
        bpe_tokenizer = None
        if getattr(args, "bpemodel", None) is not None:
            bpe_tokenizer = SentencepiecesTokenizer(args.bpemodel)
        self.dataset_conf = args.dataset_conf
        self.frontend_conf = args.frontend_conf if "frontend_conf" in args else None
        self.speed_perturb = getattr(args, "speed_perturb", None)
        logging.info("dataloader config: {}".format(self.dataset_conf))
        batch_mode = self.dataset_conf.get("batch_mode", "padding")
        data_list = args.train_data_file if mode == "train" else args.valid_data_file
        self.dataset = Dataset(
            data_list,
            symbol_table,
            seg_dict,
            punc_dict,
            bpe_tokenizer,
            self.dataset_conf,
            self.frontend_conf,
            # Speed perturbation is a training-only augmentation.
            speed_perturb=self.speed_perturb if mode == "train" else None,
            mode=mode,
            batch_mode=batch_mode,
        )

    def build_iter(self, epoch, shuffle=True):
        """Return a DataLoader for `epoch`; batching happens inside the dataset."""
        self.dataset.set_epoch(epoch)
        return DataLoader(
            self.dataset,
            batch_size=None,  # the IterableDataset already yields ready batches
            pin_memory=True,
            num_workers=self.dataset_conf.get("num_workers", 8),
        )

View File

@ -1,194 +0,0 @@
from typing import Collection
from typing import Dict
from typing import List
from typing import Tuple
from typing import Union
import numpy as np
import torch
from funasr.models.transformer.utils.nets_utils import pad_list, pad_list_all_dim
class CommonCollateFn:
    """Functor wrapper around common_collate_fn().

    Stores the pad values / non-sequence keys once and applies them to every
    batch passed to __call__.
    """

    def __init__(
        self,
        float_pad_value: Union[float, int] = 0.0,
        int_pad_value: int = -32768,
        not_sequence: Collection[str] = (),
        max_sample_size=None,
    ):
        self.float_pad_value = float_pad_value
        self.int_pad_value = int_pad_value
        self.not_sequence = set(not_sequence)
        self.max_sample_size = max_sample_size

    def __repr__(self):
        # BUG FIX: previously printed float_pad_value for both fields.
        return (
            f"{self.__class__}(float_pad_value={self.float_pad_value}, "
            f"int_pad_value={self.int_pad_value})"
        )

    def __call__(
        self, data: Collection[Tuple[str, Dict[str, np.ndarray]]]
    ) -> Tuple[List[str], Dict[str, torch.Tensor]]:
        return common_collate_fn(
            data,
            float_pad_value=self.float_pad_value,
            int_pad_value=self.int_pad_value,
            not_sequence=self.not_sequence,
        )
def common_collate_fn(
    data: Collection[Tuple[str, Dict[str, np.ndarray]]],
    float_pad_value: Union[float, int] = 0.0,
    int_pad_value: int = -32768,
    not_sequence: Collection[str] = (),
) -> Tuple[List[str], Dict[str, torch.Tensor]]:
    """Concatenate ndarray-list to an array and convert to torch.Tensor."""
    uttids = [u for u, _ in data]
    samples = [d for _, d in data]
    assert all(set(samples[0]) == set(d) for d in samples), "dict-keys mismatching"
    assert all(
        not k.endswith("_lengths") for k in samples[0]
    ), f"*_lengths is reserved: {list(samples[0])}"
    batch = {}
    for key in samples[0]:
        # Integer arrays get the int pad value, everything else the float one.
        pad_value = int_pad_value if samples[0][key].dtype.kind == "i" else float_pad_value
        tensors = [torch.from_numpy(s[key]) for s in samples]
        batch[key] = pad_list(tensors, pad_value)
        if key not in not_sequence:
            # Record the original (pre-padding) length of each sequence.
            batch[key + "_lengths"] = torch.tensor(
                [s[key].shape[0] for s in samples], dtype=torch.long
            )
    return uttids, batch
class DiarCollateFn:
    """Functor wrapper around diar_collate_fn() (diarization batches).

    Same contract as CommonCollateFn, but padding is applied over all tensor
    dimensions via diar_collate_fn.
    """

    def __init__(
        self,
        float_pad_value: Union[float, int] = 0.0,
        int_pad_value: int = -32768,
        not_sequence: Collection[str] = (),
        max_sample_size=None,
    ):
        self.float_pad_value = float_pad_value
        self.int_pad_value = int_pad_value
        self.not_sequence = set(not_sequence)
        self.max_sample_size = max_sample_size

    def __repr__(self):
        # BUG FIX: previously printed float_pad_value for both fields.
        return (
            f"{self.__class__}(float_pad_value={self.float_pad_value}, "
            f"int_pad_value={self.int_pad_value})"
        )

    def __call__(
        self, data: Collection[Tuple[str, Dict[str, np.ndarray]]]
    ) -> Tuple[List[str], Dict[str, torch.Tensor]]:
        return diar_collate_fn(
            data,
            float_pad_value=self.float_pad_value,
            int_pad_value=self.int_pad_value,
            not_sequence=self.not_sequence,
        )
def diar_collate_fn(
    data: Collection[Tuple[str, Dict[str, np.ndarray]]],
    float_pad_value: Union[float, int] = 0.0,
    int_pad_value: int = -32768,
    not_sequence: Collection[str] = (),
) -> Tuple[List[str], Dict[str, torch.Tensor]]:
    """Concatenate ndarray-list to an array and convert to torch.Tensor."""
    uttids = [u for u, _ in data]
    samples = [d for _, d in data]
    assert all(set(samples[0]) == set(d) for d in samples), "dict-keys mismatching"
    assert all(
        not k.endswith("_lengths") for k in samples[0]
    ), f"*_lengths is reserved: {list(samples[0])}"
    batch = {}
    for key in samples[0]:
        # Integer arrays get the int pad value, everything else the float one.
        pad_value = int_pad_value if samples[0][key].dtype.kind == "i" else float_pad_value
        tensors = [torch.from_numpy(s[key]) for s in samples]
        # Unlike common_collate_fn, padding spans every tensor dimension.
        batch[key] = pad_list_all_dim(tensors, pad_value)
        if key not in not_sequence:
            batch[key + "_lengths"] = torch.tensor(
                [s[key].shape[0] for s in samples], dtype=torch.long
            )
    return uttids, batch
def crop_to_max_size(feature, target_size):
    """Randomly crop `feature` along its first axis down to `target_size`.

    Returns the input unchanged when it is already short enough; otherwise a
    contiguous, uniformly-placed window of exactly `target_size` frames.
    """
    excess = len(feature) - target_size
    if excess <= 0:
        return feature
    start = np.random.randint(0, excess + 1)
    return feature[start : start + target_size]
def clipping_collate_fn(
    data: Collection[Tuple[str, Dict[str, np.ndarray]]],
    max_sample_size=None,
    not_sequence: Collection[str] = (),
) -> Tuple[List[str], Dict[str, torch.Tensor]]:
    # mainly for pre-training
    # Collate by randomly cropping every sequence to one common length
    # (the shortest sample in the batch, optionally capped by max_sample_size).
    uttids = [u for u, _ in data]
    samples = [d for _, d in data]
    assert all(set(samples[0]) == set(d) for d in samples), "dict-keys mismatching"
    assert all(
        not k.endswith("_lengths") for k in samples[0]
    ), f"*_lengths is reserved: {list(samples[0])}"
    batch = {}
    for key in samples[0]:
        tensors = [torch.from_numpy(s[key]) for s in samples]
        sizes = [len(t) for t in tensors]
        if max_sample_size is None:
            target_size = min(sizes)
        else:
            target_size = min(min(sizes), max_sample_size)
        out = tensors[0].new_zeros(len(tensors), target_size, tensors[0].shape[1])
        for i, (source, size) in enumerate(zip(tensors, sizes)):
            out[i] = source if size == target_size else crop_to_max_size(source, target_size)
        batch[key] = out
        if key not in not_sequence:
            # After clipping every row has length target_size.
            batch[key + "_lengths"] = torch.tensor(
                [row.shape[0] for row in out], dtype=torch.long
            )
    return uttids, batch

View File

@ -1,213 +0,0 @@
import random
from itertools import count
from functools import partial
from torch.utils.data import IterableDataset
from funasr.datasets.large_datasets.datapipes.map import MapperIterDataPipe
# Monotonic counter appended to every sort key so that tuples can always be
# ordered without ever comparing the (possibly unorderable) payloads.
tiebreaker = count()


def _default_len_fn(token):
    """Default length key: (len(token), unique tie-break id)."""
    return len(token), next(tiebreaker)


def _token_len_fn(token, len_fn):
    """Wrap a sample as (length, tie-break id, sample) for bucket sorting."""
    return len_fn(token), next(tiebreaker), token
class MaxTokenBucketizerIterDataPipe(IterableDataset):
    """Groups upstream samples into batches bounded by a token budget.

    Samples flow through a shuffle buffer (`buffer_size`), are sorted by
    length in chunks of `sort_size`, and are packed into batches whose cost
    estimate stays at or under `batch_size`.  Two packing modes:

    - "padding": cost = max length in batch * batch count (padded batch area).
    - "clipping": cost = min length in batch * batch count (clipped batch area).

    buffer_size == -1 sorts the whole stream at once; buffer_size == 0
    disables the shuffle buffer entirely (padding mode only).
    """

    def __init__(
        self,
        datapipe,
        batch_size=8000,
        len_fn=_default_len_fn,
        buffer_size=10240,
        sort_size=500,
        batch_mode="padding",
    ):
        assert batch_size > 0, "Batch size is required to be larger than 0!"
        assert buffer_size >= -1, "Buffer size is required to be larger than -1!"
        assert sort_size > 0, "Sort size is required to be larger than 0!"
        # Each sample becomes (length, tiebreak, sample) so buckets can be
        # sorted by length without comparing samples themselves.
        datapipe = MapperIterDataPipe(datapipe, fn=partial(_token_len_fn, len_fn=len_fn))
        self.datapipe = datapipe
        self.batch_size = batch_size
        self.buffer_size = buffer_size
        self.sort_size = sort_size
        self.batch_mode = batch_mode

    def set_epoch(self, epoch):
        # Propagate the epoch upstream for epoch-deterministic shuffling.
        self.datapipe.set_epoch(epoch)

    def __iter__(self):
        # `buffer` holds raw samples awaiting shuffling, `bucket` holds a
        # sort chunk, `batch` is the batch currently being filled.  Note that
        # `batch`, `min_lengths`/`max_lengths` persist across buffer/bucket
        # flushes, so batches may span flush boundaries.
        buffer = []
        batch = []
        bucket = []
        max_lengths = 0
        min_lengths = 999999
        batch_lengths = 0
        if self.batch_mode == "clipping":
            # Clipping mode: batch cost is min-length * batch-size, since all
            # sequences will later be cropped down to the shortest one.
            assert self.buffer_size > 0, "for clipping batch_mode, buffer_size must be > 1"
            for d in self.datapipe:
                # Drop samples that could never fit in any batch.
                if d[0] > self.batch_size:
                    continue
                buffer.append(d)
                if len(buffer) == self.buffer_size:
                    # Shuffle the buffer, then sort it chunk-wise and pack.
                    random.shuffle(buffer)
                    for sample in buffer:
                        bucket.append(sample)
                        if len(bucket) == self.sort_size:
                            bucket.sort()
                            for x in bucket:
                                length, _, token = x
                                if length < min_lengths:
                                    min_lengths = length
                                batch_lengths = min_lengths * (len(batch) + 1)
                                if batch_lengths > self.batch_size:
                                    # Budget exceeded: emit and start a new batch.
                                    yield batch
                                    batch = []
                                    min_lengths = length
                                batch.append(token)
                            bucket = []
                    buffer = []
            # Flush the partially-filled buffer at end of stream.
            if buffer:
                random.shuffle(buffer)
                for sample in buffer:
                    bucket.append(sample)
                    if len(bucket) == self.sort_size:
                        bucket.sort()
                        for x in bucket:
                            length, _, token = x
                            if length < min_lengths:
                                min_lengths = length
                            batch_lengths = min_lengths * (len(batch) + 1)
                            if batch_lengths > self.batch_size:
                                yield batch
                                batch = []
                                min_lengths = length
                            batch.append(token)
                        bucket = []
                buffer = []
            # Flush the last (short) sort chunk.
            if bucket:
                bucket.sort()
                for x in bucket:
                    length, _, token = x
                    if length < min_lengths:
                        min_lengths = length
                    batch_lengths = min_lengths * (len(batch) + 1)
                    if batch_lengths > self.batch_size:
                        yield batch
                        batch = []
                        min_lengths = length
                    batch.append(token)
                bucket = []
            # Emit the final, possibly under-full batch.
            if batch:
                yield batch
        else:
            # Padding mode: batch cost is max-length * batch-size, matching
            # the memory of a padded batch.
            if self.buffer_size == -1:
                # Global sort: read the whole stream, sort once, pack, then
                # shuffle whole batches before yielding.
                for d in self.datapipe:
                    if d[0] > self.batch_size:
                        continue
                    buffer.append(d)
                buffer.sort()
                for sample in buffer:
                    length, _, token = sample
                    if length > max_lengths:
                        max_lengths = length
                    batch_lengths = max_lengths * (len(batch) + 1)
                    if batch_lengths > self.batch_size:
                        bucket.append(batch)
                        batch = []
                        max_lengths = length
                    batch.append(token)
                random.shuffle(bucket)
                if bucket:
                    for batch_sample in bucket:
                        yield batch_sample
                if batch:
                    yield batch
            elif self.buffer_size == 0:
                # No buffering/sorting: pack samples in arrival order.
                for d in self.datapipe:
                    if d[0] > self.batch_size:
                        continue
                    length, _, token = d
                    if length > self.batch_size:
                        continue
                    if length > max_lengths:
                        max_lengths = length
                    batch_lengths = max_lengths * (len(batch) + 1)
                    if batch_lengths > self.batch_size:
                        yield batch
                        batch = []
                        max_lengths = length
                    batch.append(token)
                if batch:
                    yield batch
            else:
                # Windowed shuffle + chunked sort, mirroring clipping mode but
                # tracking the running maximum length.
                for d in self.datapipe:
                    if d[0] > self.batch_size:
                        continue
                    buffer.append(d)
                    if len(buffer) == self.buffer_size:
                        random.shuffle(buffer)
                        for sample in buffer:
                            bucket.append(sample)
                            if len(bucket) == self.sort_size:
                                bucket.sort()
                                for x in bucket:
                                    length, _, token = x
                                    if length > max_lengths:
                                        max_lengths = length
                                    batch_lengths = max_lengths * (len(batch) + 1)
                                    if batch_lengths > self.batch_size:
                                        yield batch
                                        batch = []
                                        max_lengths = length
                                    batch.append(token)
                                bucket = []
                        buffer = []
                # Flush the partially-filled buffer at end of stream.
                if buffer:
                    random.shuffle(buffer)
                    for sample in buffer:
                        bucket.append(sample)
                        if len(bucket) == self.sort_size:
                            bucket.sort()
                            for x in bucket:
                                length, _, token = x
                                if length > max_lengths:
                                    max_lengths = length
                                batch_lengths = max_lengths * (len(batch) + 1)
                                if batch_lengths > self.batch_size:
                                    yield batch
                                    batch = []
                                    max_lengths = length
                                batch.append(token)
                            bucket = []
                    buffer = []
                # Flush the last (short) sort chunk.
                if bucket:
                    bucket.sort()
                    for x in bucket:
                        length, _, token = x
                        if length > max_lengths:
                            max_lengths = length
                        batch_lengths = max_lengths * (len(batch) + 1)
                        if batch_lengths > self.batch_size:
                            yield batch
                            batch = []
                            max_lengths = length
                        batch.append(token)
                    bucket = []
                if batch:
                    yield batch

View File

@ -1,23 +0,0 @@
from torch.utils.data import IterableDataset
def default_fn(data):
    """Identity predicate: keeps every truthy sample."""
    return data


class FilterIterDataPipe(IterableDataset):
    """IterableDataset wrapper that yields only samples for which fn(data) is truthy."""

    def __init__(self, datapipe, fn=default_fn):
        self.datapipe = datapipe
        self.fn = fn

    def set_epoch(self, epoch):
        # Propagate the epoch so upstream shuffling stays epoch-deterministic.
        self.datapipe.set_epoch(epoch)

    def __iter__(self):
        assert callable(self.fn)
        for data in self.datapipe:
            # Cleanup: dropped the redundant `else: continue` branch.
            if self.fn(data):
                yield data

View File

@ -1,20 +0,0 @@
from torch.utils.data import IterableDataset
def default_fn(data):
    """Identity transform."""
    return data


class MapperIterDataPipe(IterableDataset):
    """IterableDataset wrapper applying fn to every upstream sample."""

    def __init__(self, datapipe, fn=default_fn):
        self.datapipe = datapipe
        self.fn = fn

    def set_epoch(self, epoch):
        # Propagate the epoch so upstream shuffling stays epoch-deterministic.
        self.datapipe.set_epoch(epoch)

    def __iter__(self):
        assert callable(self.fn)
        for sample in self.datapipe:
            yield self.fn(sample)

View File

@ -1,299 +0,0 @@
import logging
import os
import random
from functools import partial
import torch
import torch.distributed as dist
import torchaudio
import numpy as np
# import librosa
import librosa
from kaldiio import ReadHelper
from torch.utils.data import IterableDataset
from funasr.datasets.large_datasets.datapipes.batch import MaxTokenBucketizerIterDataPipe
from funasr.datasets.large_datasets.datapipes.filter import FilterIterDataPipe
from funasr.datasets.large_datasets.datapipes.map import MapperIterDataPipe
from funasr.datasets.large_datasets.utils.clipping import clipping
from funasr.datasets.large_datasets.utils.filter import filter
from funasr.datasets.large_datasets.utils.padding import padding
from funasr.datasets.large_datasets.utils.tokenize import tokenize
def read_lists(list_file):
    """Read a text file into a list of stripped lines (one entry per line)."""
    with open(list_file, "r", encoding="utf8") as fin:
        return [line.strip() for line in fin]
class AudioDataset(IterableDataset):
    """Iterable dataset that streams utterances from scp-style data lists.

    Each element of `scp_lists` is one line of a data list file: a
    whitespace-separated group of file paths aligned with the comma-separated
    `data_names` / `data_types` strings (e.g. "speech,text" with
    "kaldi_ark,text").  Indices are sharded twice: round-robin across
    distributed ranks, then across DataLoader workers.
    """

    def __init__(
        self,
        scp_lists,
        data_names,
        data_types,
        frontend_conf=None,
        shuffle=True,
        speed_perturb=None,
        mode="train",
    ):
        self.scp_lists = scp_lists
        self.data_names = data_names
        self.data_types = data_types
        self.frontend_conf = frontend_conf
        self.shuffle = shuffle
        self.mode = mode
        # Injected via set_epoch(); seeds the per-epoch shuffle below.
        self.epoch = -1
        # Sharding info, refreshed at the start of each __iter__ call.
        self.rank = 0
        self.world_size = 1
        self.worker_id = 0
        self.num_workers = 1
        self.speed_perturb = speed_perturb
        if self.speed_perturb is not None:
            logging.info("Using speed_perturb: {}".format(speed_perturb))

    def set_epoch(self, epoch):
        # Called once per epoch by the surrounding iter factory.
        self.epoch = epoch

    def get_rank_data_list(self, data_index):
        """Shard (and, in train mode, shuffle) the index list across ranks."""
        assert dist.is_available()
        if dist.is_initialized():
            self.rank = dist.get_rank()
            self.world_size = dist.get_world_size()
        else:
            self.rank = 0
            self.world_size = 1
        if self.mode == "train":
            if self.shuffle:
                # Seeding with the epoch makes the shuffle identical on every
                # rank, so the strided split below stays disjoint.
                random.seed(self.epoch)
                random.shuffle(data_index)
            return data_index[self.rank :: self.world_size]
        return data_index

    def get_worker_data_list(self, rank_data_index):
        """Further shard this rank's indices across DataLoader workers."""
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            self.worker_id = 0
            self.num_workers = 1
        else:
            self.worker_id = worker_info.id
            self.num_workers = worker_info.num_workers
        return rank_data_index[self.worker_id :: self.num_workers]

    def close_reader(self, reader_list):
        # Close every ark/text reader opened for the current scp line.
        for reader in reader_list:
            reader.close()

    def __iter__(self):
        data_index = list(range(len(self.scp_lists)))
        rank_data_index = self.get_rank_data_list(data_index)
        worker_data_index = self.get_worker_data_list(rank_data_index)
        for index in worker_data_index:
            data = dict(scp=self.scp_lists[index])
            assert "scp" in data
            scp = data["scp"]
            # One scp line: parallel files, one per data_name/data_type pair.
            data_file_list = scp.strip().split()
            data_name_list = self.data_names.split(",")
            data_type_list = self.data_types.split(",")
            for file in data_file_list:
                assert os.path.exists(file), "{} not exists".format(file)
            assert (
                len(data_file_list) == len(data_name_list) == len(data_type_list)
            ), "The item number of data, data_names, data_types must be the same "
            # Open one reader per file; "none" entries get no reader.
            reader_list = []
            for data_file, data_type in zip(data_file_list, data_type_list):
                if data_type == "kaldi_ark":
                    ark_reader = ReadHelper("ark:{}".format(data_file))
                    reader_list.append(ark_reader)
                elif data_type == "text" or data_type == "sound" or data_type == "text_hotword":
                    text_reader = open(data_file, "r", encoding="utf-8")
                    reader_list.append(text_reader)
                elif data_type == "none":
                    continue
                else:
                    raise TypeError("Data type {} is not supported".format(data_type))
            # Iterate all readers in lockstep; each step is one utterance.
            for items in zip(*reader_list):
                sample_dict = {}
                for item, (data_name, data_type) in zip(items, zip(data_name_list, data_type_list)):
                    if data_type == "kaldi_ark":
                        key, mat = item
                        sample_dict[data_name] = mat
                        if data_name == "speech":
                            sample_dict["key"] = key
                    elif data_type == "sound":
                        # Line format: "<utt_key> <wav_path>".
                        key, path = item.strip().split()
                        try:
                            waveform, sampling_rate = torchaudio.load(path)
                        except:
                            # NOTE(review): bare except — any torchaudio decode
                            # failure falls back to librosa (mono, float32),
                            # reshaped to the (1, T) layout torchaudio returns.
                            # waveform, sampling_rate = librosa.load(path, dtype='float32')
                            waveform, sampling_rate = librosa.load(path, dtype="float32")
                            if waveform.ndim == 2:
                                waveform = waveform[:, 0]
                            waveform = np.expand_dims(waveform, axis=0)
                            waveform = torch.tensor(waveform)
                        if self.frontend_conf is not None:
                            # Resample to the frontend's expected rate if needed.
                            if sampling_rate != self.frontend_conf["fs"]:
                                waveform = torchaudio.transforms.Resample(
                                    orig_freq=sampling_rate, new_freq=self.frontend_conf["fs"]
                                )(waveform)
                                sampling_rate = self.frontend_conf["fs"]
                        waveform = waveform.numpy()
                        # First channel only.
                        mat = waveform[0]
                        if self.speed_perturb is not None:
                            # Training-time speed perturbation via sox effects.
                            speed = random.choice(self.speed_perturb)
                            if speed != 1.0:
                                mat, _ = torchaudio.sox_effects.apply_effects_tensor(
                                    torch.tensor(mat).view(1, -1),
                                    sampling_rate,
                                    [["speed", str(speed)], ["rate", str(sampling_rate)]],
                                )
                                mat = mat.view(-1).numpy()
                        sample_dict[data_name] = mat
                        sample_dict["sampling_rate"] = sampling_rate
                        if data_name == "speech":
                            sample_dict["key"] = key
                    elif data_type == "text_hotword":
                        # "<utt_key> <tok> <tok> ..." plus a tag telling
                        # tokenize() to sample hotword spans for this sample.
                        text = item
                        segs = text.strip().split()
                        sample_dict[data_name] = segs[1:]
                        if "key" not in sample_dict:
                            sample_dict["key"] = segs[0]
                        sample_dict["hw_tag"] = 1
                    elif data_type == "text_nospace":
                        # Character-level tokens (no whitespace segmentation).
                        text = item
                        segs = text.strip().split(maxsplit=1)
                        sample_dict[data_name] = [x for x in segs[1]]
                        if "key" not in sample_dict:
                            sample_dict["key"] = segs[0]
                    else:
                        # Plain whitespace-tokenized text.
                        text = item
                        segs = text.strip().split()
                        sample_dict[data_name] = segs[1:]
                        if "key" not in sample_dict:
                            sample_dict["key"] = segs[0]
                yield sample_dict
            self.close_reader(reader_list)
def len_fn_example(data):
    """Batch-cost key for batch_type == "example": every sample counts as 1."""
    return 1
def len_fn_token(data):
    """Batch-cost key for batch_type == "token".

    Returns the utterance duration in milliseconds when the sampling rate is
    known, otherwise the raw number of frames/samples.
    """
    assert "speech" in data
    n_frames = data["speech"].shape[0]
    if "sampling_rate" in data:
        return n_frames / data["sampling_rate"] * 1000.0
    return n_frames
def Dataset(
    data_list_file,
    dict,
    seg_dict,
    punc_dict,
    bpe_tokenizer,
    conf,
    frontend_conf,
    speed_perturb=None,
    mode="train",
    batch_mode="padding",
):
    """Assemble the streaming pipeline:
    AudioDataset -> tokenize -> filter -> length bucketing -> padding/clipping.

    Note: the second parameter shadows the builtin `dict`; it is the
    token -> id vocabulary produced by read_symbol_table().
    Returns the final MapperIterDataPipe yielding collated (keys, batch) pairs.
    """
    scp_lists = read_lists(data_list_file)
    shuffle = conf.get("shuffle", True)
    data_names = conf.get("data_names", "speech,text")
    data_types = conf.get("data_types", "kaldi_ark,text")
    # Optional preset hotword list file (one hotword per line).
    pre_hwfile = conf.get("pre_hwlist", None)
    # pre_prob = conf.get("pre_prob", 0) # unused yet
    if pre_hwfile is not None:
        pre_hwlist = []
        with open(pre_hwfile, "r", encoding="utf-8") as fin:
            for line in fin.readlines():
                pre_hwlist.append(line.strip())
    else:
        pre_hwlist = None
    # Hotword sampling hyper-parameters consumed by tokenize()/sample_hotword().
    hw_config = {
        "sample_rate": conf.get("sample_rate", 0.6),
        "double_rate": conf.get("double_rate", 0.1),
        "hotword_min_length": conf.get("hotword_min_length", 2),
        "hotword_max_length": conf.get("hotword_max_length", 8),
        "pre_prob": conf.get("pre_prob", 0.0),
        "pre_hwlist": pre_hwlist,
    }
    dataset = AudioDataset(
        scp_lists,
        data_names,
        data_types,
        frontend_conf=frontend_conf,
        shuffle=shuffle,
        speed_perturb=speed_perturb,
        mode=mode,
    )
    if "text" in data_names:
        # Text present: map raw token strings to id arrays.
        vocab = {
            "vocab": dict,
            "seg_dict": seg_dict,
            "punc_dict": punc_dict,
            "bpe_tokenizer": bpe_tokenizer,
            "hw_config": hw_config,
        }
        tokenize_fn = partial(tokenize, **vocab)
        dataset = MapperIterDataPipe(dataset, fn=tokenize_fn)
    # Drop samples outside the configured speech/token length ranges.
    filter_conf = conf.get("filter_conf", {})
    filter_fn = partial(filter, **filter_conf)
    dataset = FilterIterDataPipe(dataset, fn=filter_fn)
    if shuffle:
        # shuffle_conf must provide both sizes when shuffling is enabled.
        buffer_conf = conf.get("shuffle_conf", {})
        buffer_size = buffer_conf["shuffle_size"]
        sort_size = buffer_conf["sort_size"]
    else:
        buffer_size = 0
        sort_size = 1
    batch_conf = conf.get("batch_conf", {})
    batch_size = batch_conf["batch_size"]
    batch_type = batch_conf["batch_type"]
    assert batch_type in ["example", "token"]
    if batch_type == "example":
        len_fn = len_fn_example
    else:
        len_fn = len_fn_token
    dataset = MaxTokenBucketizerIterDataPipe(
        dataset,
        batch_size=batch_size,
        len_fn=len_fn,
        buffer_size=buffer_size,
        sort_size=sort_size,
        batch_mode=batch_mode,
    )
    # Final collate stage: pad to the longest sample, or clip to the shortest.
    int_pad_value = conf.get("int_pad_value", -1)
    float_pad_value = conf.get("float_pad_value", 0.0)
    padding_conf = {"int_pad_value": int_pad_value, "float_pad_value": float_pad_value}
    padding_fn = partial(padding, **padding_conf)
    dataset = MapperIterDataPipe(dataset, fn=padding_fn if batch_mode == "padding" else clipping)
    return dataset

View File

@ -1,44 +0,0 @@
import numpy as np
import torch
from funasr.datasets.large_datasets.collate_fn import crop_to_max_size
def clipping(data):
    """Collate a list of sample dicts by cropping every sequence field to the
    batch's shortest length.

    Returns (keys, batch); every non-"key" field gets a companion
    "<name>_lengths" tensor (all equal to the clipped length).
    """
    assert isinstance(data, list)
    assert "key" in data[0]
    keys = [sample["key"] for sample in data]
    batch = {}
    for name in data[0].keys():
        if name == "key":
            continue
        dtype = torch.int64 if data[0][name].dtype.kind == "i" else torch.float32
        tensors = [torch.tensor(np.copy(sample[name]), dtype=dtype) for sample in data]
        lengths = torch.tensor([len(sample[name]) for sample in data], dtype=torch.int32)
        # Everything is cropped down to the shortest sample in the batch.
        target = min(lengths)
        clipped = tensors[0].new_zeros(len(tensors), target, tensors[0].shape[1])
        for i, (tensor, length) in enumerate(zip(tensors, lengths)):
            gap = length - target
            assert gap >= 0
            clipped[i] = tensor if gap == 0 else crop_to_max_size(tensor, target)
        batch[name] = clipped
        batch[name + "_lengths"] = torch.tensor(
            [row.shape[0] for row in clipped], dtype=torch.long
        )
    return keys, batch

View File

@ -1,27 +0,0 @@
#!/usr/bin/env python
def filter(
    data, speech_length_min=100, speech_length_max=15000, token_length_min=0, token_length_max=200
):
    """Return True when the sample's speech duration and/or token count fall
    strictly within the configured bounds.

    Only the fields present in `data` are checked; at least one of "speech"
    or "text" must exist.
    """
    assert "speech" in data or "text" in data

    speech_ok = True
    if "speech" in data:
        # Duration in ms when the sampling rate is known, else raw frame count.
        if "sampling_rate" in data:
            speech_length = data["speech"].shape[0] / data["sampling_rate"] * 1000.0
        else:
            speech_length = data["speech"].shape[0]
        speech_ok = speech_length_min < speech_length < speech_length_max

    token_ok = True
    if "text" in data:
        token_ok = token_length_min < len(data["text"]) < token_length_max

    return speech_ok and token_ok

View File

@ -1,42 +0,0 @@
import random
def sample_hotword(
length,
hotword_min_length,
hotword_max_length,
sample_rate,
double_rate,
pre_prob,
pre_index=None,
pre_hwlist=None,
):
if length < hotword_min_length:
return [-1]
if random.random() < sample_rate:
if pre_prob > 0 and random.random() < pre_prob and pre_index is not None:
return pre_index
if length == hotword_min_length:
return [0, length - 1]
elif random.random() < double_rate and length > hotword_max_length + hotword_min_length + 2:
# sample two hotwords in a sentence
_max_hw_length = min(hotword_max_length, length // 2)
# first hotword
start1 = random.randint(0, length // 3)
end1 = random.randint(start1 + hotword_min_length - 1, start1 + _max_hw_length - 1)
# second hotword
start2 = random.randint(end1 + 1, length - hotword_min_length)
end2 = random.randint(
min(length - 1, start2 + hotword_min_length - 1),
min(length - 1, start2 + hotword_max_length - 1),
)
return [start1, end1, start2, end2]
else: # single hotword
start = random.randint(0, length - hotword_min_length)
end = random.randint(
min(length - 1, start + hotword_min_length - 1),
min(length - 1, start + hotword_max_length - 1),
)
return [start, end]
else:
return [-1]

View File

@ -1,30 +0,0 @@
import numpy as np
def build_LFR_features(data, m, n):
"""
Actually, this implements stacking frames and skipping frames.
if m = 1 and n = 1, just return the origin features.
if m = 1 and n > 1, it works like skipping.
if m > 1 and n = 1, it works like stacking but only support right frames.
if m > 1 and n > 1, it works like LFR.
Args:
inputs_batch: inputs is T x D np.ndarray
m: number of frames to stack
n: number of frames to skip
"""
LFR_inputs = []
T = data.shape[0]
T_lfr = int(np.ceil(T / n))
for i in range(T_lfr):
if m <= T - i * n:
LFR_inputs.append(np.hstack(data[i * n : i * n + m]))
else:
num_padding = m - (T - i * n)
frame = np.hstack(data[i * n :])
for _ in range(num_padding):
frame = np.hstack((frame, data[-1]))
LFR_inputs.append(frame)
return np.vstack(LFR_inputs)

View File

@ -1,72 +0,0 @@
import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence
def padding(data, float_pad_value=0.0, int_pad_value=-1):
    """Collate a list of sample dicts into padded batch tensors.

    Returns (keys, batch): every batched field is padded along time with a
    companion "<name>_lengths" tensor.  When "hotword_indxs" is present (see
    tokenize()), hotword targets are additionally sliced out into
    "hotword_pad" / "hotword_lengths" / "dha_pad".
    """
    assert isinstance(data, list)
    assert "key" in data[0]
    assert "speech" in data[0] or "text" in data[0]
    keys = [x["key"] for x in data]
    batch = {}
    data_names = data[0].keys()
    for data_name in data_names:
        if data_name == "key" or data_name == "sampling_rate":
            # Bookkeeping fields are not batched.
            continue
        else:
            if data_name != "hotword_indxs":
                if data[0][data_name].dtype.kind == "i":
                    pad_value = int_pad_value
                    tensor_type = torch.int64
                else:
                    pad_value = float_pad_value
                    tensor_type = torch.float32
            # NOTE(review): for "hotword_indxs" the pad_value/tensor_type from
            # the previous field are reused — this relies on dict iteration
            # order putting another field first; confirm intended.
            tensor_list = [torch.tensor(np.copy(d[data_name]), dtype=tensor_type) for d in data]
            tensor_lengths = torch.tensor([len(d[data_name]) for d in data], dtype=torch.int32)
            tensor_pad = pad_sequence(tensor_list, batch_first=True, padding_value=pad_value)
            batch[data_name] = tensor_pad
            batch[data_name + "_lengths"] = tensor_lengths
    # SAC LABEL INCLUDE
    if "hotword_indxs" in batch:
        # if hotword indxs in batch
        # use it to slice hotwords out
        hotword_list = []
        hotword_lengths = []
        text = batch["text"]
        text_lengths = batch["text_lengths"]
        hotword_indxs = batch["hotword_indxs"]
        # dha_pad: per-token target — hotword token ids inside sampled spans,
        # a filler id elsewhere, -1 beyond each utterance's length.
        dha_pad = torch.ones_like(text) * -1
        _, t1 = text.shape
        t1 += 1  # TODO: as parameter which is same as predictor_bias
        nth_hw = 0
        for b, (hotword_indx, one_text, length) in enumerate(
            zip(hotword_indxs, text, text_lengths)
        ):
            # NOTE(review): 8405 looks like a special "non-hotword" label id —
            # confirm against the vocabulary.
            dha_pad[b][:length] = 8405
            if hotword_indx[0] != -1:
                # First sampled span (inclusive indices).
                start, end = int(hotword_indx[0]), int(hotword_indx[1])
                hotword = one_text[start : end + 1]
                hotword_list.append(hotword)
                hotword_lengths.append(end - start + 1)
                dha_pad[b][start : end + 1] = one_text[start : end + 1]
                nth_hw += 1
                if len(hotword_indx) == 4 and hotword_indx[2] != -1:
                    # the second hotword if exist
                    start, end = int(hotword_indx[2]), int(hotword_indx[3])
                    hotword_list.append(one_text[start : end + 1])
                    hotword_lengths.append(end - start + 1)
                    dha_pad[b][start : end + 1] = one_text[start : end + 1]
                    nth_hw += 1
        # Trailing length-1 dummy hotword appended to every batch.
        hotword_list.append(torch.tensor([1]))
        hotword_lengths.append(1)
        hotword_pad = pad_sequence(hotword_list, batch_first=True, padding_value=0)
        batch["hotword_pad"] = hotword_pad
        batch["hotword_lengths"] = torch.tensor(hotword_lengths, dtype=torch.int32)
        batch["dha_pad"] = dha_pad
        # Raw index fields are consumed here and removed from the batch.
        del batch["hotword_indxs"]
        del batch["hotword_indxs_lengths"]
    return keys, batch

View File

@ -1,93 +0,0 @@
#!/usr/bin/env python
import re
import numpy as np
from funasr.datasets.large_datasets.utils.hotword_utils import sample_hotword
def forward_segment(text, seg_dict):
    """Greedy forward maximum matching.

    Splits `text` into the longest words found in `seg_dict`, falling back to
    single characters for unmatched positions.
    """
    words = []
    pos = 0
    while pos < len(text):
        best = text[pos]
        for end in range(pos + 1, len(text) + 1):
            candidate = text[pos:end]
            if candidate in seg_dict and len(candidate) > len(best):
                best = candidate
        words.append(best)
        pos += len(best)
    return words
def seg_tokenize(txt, seg_dict):
    """Map each word in `txt` to its segmented form via `seg_dict`.

    Words are lower-cased first.  Unknown words made purely of CJK chars or
    digits are retried character by character; anything else unknown becomes
    "<unk>".  Returns the flat list of resulting pieces.
    """
    pattern = re.compile(r"^[\u4E00-\u9FA50-9]+$")
    pieces = []
    for word in txt:
        word = word.lower()
        if word in seg_dict:
            pieces.append(seg_dict[word])
        elif pattern.match(word):
            # CJK/digit-only word: fall back to per-character lookup.
            for char in word:
                pieces.append(seg_dict[char] if char in seg_dict else "<unk>")
        else:
            pieces.append("<unk>")
    # Values may contain spaces; re-split so the result is one token per item.
    return " ".join(pieces).split()
def tokenize(data, vocab=None, seg_dict=None, punc_dict=None, bpe_tokenizer=None, hw_config=None):
    """Convert data["text"] (and optionally data["punc"]) to id arrays in place.

    Pipeline: optional BPE, optional dictionary segmentation, then vocabulary
    lookup with "<unk>" fallback.  A trailing "vad:<n>" token (only honored
    when a "punc" field is present) is stripped into data["vad_indexes"].
    When the sample carries "hw_tag", hotword spans are sampled into
    data["hotword_indxs"].

    Returns the mutated `data` dict.
    """
    assert "text" in data
    assert isinstance(vocab, dict)
    text = data["text"]
    token = []
    vad = -2  # sentinel meaning "no vad token seen"
    if bpe_tokenizer is not None:
        text = bpe_tokenizer.text2tokens(" ".join(text))
    if seg_dict is not None:
        assert isinstance(seg_dict, dict)
        text = seg_tokenize(text, seg_dict)
    length = len(text)
    if "hw_tag" in data:
        pre_index = None
        if hw_config["pre_hwlist"] is not None and hw_config["pre_prob"] > 0:
            # enable preset hotword detect in sampling
            for hw in hw_config["pre_hwlist"]:
                hw = " ".join(seg_tokenize(hw, seg_dict))
                _find = " ".join(text).find(hw)
                if _find != -1:
                    # _find = text[:_find].count(" ") # bpe sometimes
                    pre_index = [_find, _find + max(hw.count(" "), 1)]
                    break
        hotword_indxs = sample_hotword(length, **hw_config, pre_index=pre_index)
        data["hotword_indxs"] = hotword_indxs
        del data["hw_tag"]
    for i in range(length):
        x = text[i]
        if i == length - 1 and "punc" in data and x.startswith("vad:"):
            # Trailing vad marker: "vad:" alone means -1, otherwise the int.
            vad = x[4:]
            vad = -1 if len(vad) == 0 else int(vad)
        elif x in vocab:
            token.append(vocab[x])
        else:
            token.append(vocab["<unk>"])
    if "punc" in data and punc_dict is not None:
        # Unknown punctuation symbols map to the "_" entry.
        punc_token = [punc_dict.get(p, punc_dict["_"]) for p in data["punc"]]
        data["punc"] = np.array(punc_token)
    data["text"] = np.array(token)
    # BUG FIX: was `vad is not -2` — an identity comparison with an int
    # literal that relies on CPython small-int caching and raises a
    # SyntaxWarning on modern Python; use value equality instead.
    if vad != -2:
        data["vad_indexes"] = np.array([vad], dtype=np.int64)
    return data

View File

@ -85,8 +85,10 @@ def download_from_ms(**kwargs):
install_requirements(requirements)
if kwargs.get("trust_remote_code", False):
from funasr.utils.dynamic_import import import_module_from_path
import model
model_code = kwargs.get("remote_code", "model")
import_module_from_path(model_code)
# from funasr.register import tables
# tables.print("model")

View File

@ -2,6 +2,7 @@ import importlib.util
import importlib.util
import inspect
import os.path
def load_module_from_path(file_path):
@ -18,6 +19,21 @@ def load_module_from_path(file_path):
return module
def import_module_from_path(file_path):
    """Import a python module given the path of its .py file.

    Temporarily chdir()s into the file's directory so that a plain
    importlib.import_module(<stem>) can resolve it, then restores the
    original working directory.  The module is registered in sys.modules
    under its bare file stem.
    """
    current_working_directory = os.getcwd()
    # Directory holding the module file (empty string for a bare filename).
    file_dir = os.path.dirname(file_path)
    file_name = os.path.basename(file_path)
    # BUG FIX: was file_path.split("/")[-1].replace(".py", ""), which breaks
    # on Windows separators and strips ".py" anywhere in the name; use the
    # os-portable basename/splitext instead.
    module_name = os.path.splitext(file_name)[0]
    try:
        if len(file_dir) > 0:
            os.chdir(file_dir)
        importlib.import_module(module_name)
    finally:
        # BUG FIX: the cwd was not restored when the import raised.
        os.chdir(current_working_directory)
#
# def load_module_from_path(module_name, file_path):
# """