diff --git a/funasr/datasets/collate_fn.py b/funasr/datasets/collate_fn.py
index d52032f9e..d34d61032 100644
--- a/funasr/datasets/collate_fn.py
+++ b/funasr/datasets/collate_fn.py
@@ -78,6 +78,58 @@ def common_collate_fn(
             lens = torch.tensor([d[key].shape[0] for d in data], dtype=torch.long)
             output[key + "_lengths"] = lens
 
+    output = (uttids, output)
+    assert check_return_type(output)
+    return output
+
+
+def crop_to_max_size(feature, target_size):
+    size = len(feature)
+    diff = size - target_size
+    if diff <= 0:
+        return feature
+
+    start = np.random.randint(0, diff + 1)
+    end = size - diff + start
+    return feature[start:end]
+
+
+def clipping_collate_fn(
+    data: Collection[Tuple[str, Dict[str, np.ndarray]]],
+    max_sample_size=None,
+    not_sequence: Collection[str] = (),
+) -> Tuple[List[str], Dict[str, torch.Tensor]]:
+    # mainly for pre-training
+    assert check_argument_types()
+    uttids = [u for u, _ in data]
+    data = [d for _, d in data]
+
+    assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching"
+    assert all(
+        not k.endswith("_lengths") for k in data[0]
+    ), f"*_lengths is reserved: {list(data[0])}"
+
+    output = {}
+    for key in data[0]:
+        array_list = [d[key] for d in data]
+        tensor_list = [torch.from_numpy(a) for a in array_list]
+        sizes = [len(s) for s in tensor_list]
+        if max_sample_size is None:
+            target_size = min(sizes)
+        else:
+            target_size = min(min(sizes), max_sample_size)
+        tensor = tensor_list[0].new_zeros(len(tensor_list), target_size, tensor_list[0].shape[1])
+        for i, (source, size) in enumerate(zip(tensor_list, sizes)):
+            diff = size - target_size
+            if diff == 0:
+                tensor[i] = source
+            else:
+                tensor[i] = crop_to_max_size(source, target_size)
+        output[key] = tensor
+
+        if key not in not_sequence:
+            lens = torch.tensor([source.shape[0] for source in tensor], dtype=torch.long)
+            output[key + "_lengths"] = lens
+
     output = (uttids, output)
     assert check_return_type(output)
     return output
\ No newline at end of file
diff --git a/funasr/datasets/large_datasets/dataset.py b/funasr/datasets/large_datasets/dataset.py
index 81c136115..2d3ffd5d7 100644
--- a/funasr/datasets/large_datasets/dataset.py
+++ b/funasr/datasets/large_datasets/dataset.py
@@ -102,6 +102,8 @@ class AudioDataset(IterableDataset):
            elif data_type == "text" or data_type == "sound":
                text_reader = open(data_file, "r")
                reader_list.append(text_reader)
+            elif data_type == "none":
+                continue
            else:
                raise TypeError("Data type {} is not supported".format(data_type))
 
diff --git a/funasr/datasets/large_datasets/utils/filter.py b/funasr/datasets/large_datasets/utils/filter.py
index 91ba4be73..1260a47c4 100644
--- a/funasr/datasets/large_datasets/utils/filter.py
+++ b/funasr/datasets/large_datasets/utils/filter.py
@@ -6,13 +6,21 @@ def filter(data,
            speech_length_max=15000,
            token_length_min=0,
            token_length_max=200):
-    assert "speech" in data
-    assert "text" in data
+    assert "speech" in data or "text" in data
 
-    if "sampling_rate" in data:
-        speech_length = (data["speech"].shape[0] / data["sampling_rate"]) * 1000.
+    if "speech" in data and "text" in data:
+        if "sampling_rate" in data:
+            speech_length = (data["speech"].shape[0] / data["sampling_rate"]) * 1000.
+        else:
+            speech_length = data["speech"].shape[0]
+        num_tokens = len(data['text'])
+        return speech_length_min < speech_length < speech_length_max and token_length_min < num_tokens < token_length_max
+    elif "speech" in data:
+        if "sampling_rate" in data:
+            speech_length = (data["speech"].shape[0] / data["sampling_rate"]) * 1000.
+        else:
+            speech_length = data["speech"].shape[0]
+        return speech_length_min < speech_length < speech_length_max
     else:
-        speech_length = data["speech"].shape[0]
-    num_tokens = len(data['text'])
-
-    return speech_length_min < speech_length < speech_length_max and token_length_min < num_tokens < token_length_max
+        num_tokens = len(data['text'])
+        return token_length_min < num_tokens < token_length_max
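
Illustration (not part of the patch): unlike common_collate_fn, which pads every item
up to the longest utterance, the new clipping_collate_fn crops every item down to the
shortest length in the batch (optionally capped by max_sample_size), taking a random
window from longer items via crop_to_max_size. The standalone sketch below replays
that logic on invented toy data; note that end = size - diff + start in the patch is
algebraically the same as start + target_size, which the sketch uses directly.

import numpy as np
import torch

def crop_to_max_size(feature, target_size):
    # Same logic as the patch: keep a random window of target_size frames.
    size = len(feature)
    diff = size - target_size
    if diff <= 0:
        return feature
    start = np.random.randint(0, diff + 1)
    return feature[start:start + target_size]  # == feature[start:size - diff + start]

# A toy batch of two utterances with 4-dim frame features (shapes are made up).
batch = [
    ("utt1", {"speech": np.random.randn(100, 4).astype(np.float32)}),
    ("utt2", {"speech": np.random.randn(60, 4).astype(np.float32)}),
]

tensor_list = [torch.from_numpy(d["speech"]) for _, d in batch]
sizes = [len(t) for t in tensor_list]
target_size = min(sizes)  # 60 here; with max_sample_size=50 it would be 50
batched = tensor_list[0].new_zeros(len(tensor_list), target_size, tensor_list[0].shape[1])
for i, (source, size) in enumerate(zip(tensor_list, sizes)):
    batched[i] = source if size == target_size else crop_to_max_size(source, target_size)

print(batched.shape)  # torch.Size([2, 60, 4]): every item cropped, nothing padded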
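
Likewise for the relaxed filter: a sample may now carry only "speech" or only "text"
(e.g. audio-only or text-only pre-training data), and each modality is checked
independently. The inputs below and the explicit speech_length_min value are
hypothetical; the hunk does not show the default for that parameter.

import numpy as np
from funasr.datasets.large_datasets.utils.filter import filter

speech_only = {"speech": np.zeros(80000), "sampling_rate": 16000}  # 80000/16000*1000 = 5000 ms
text_only = {"text": ["h", "e", "l", "l", "o"]}                    # 5 tokens
paired = {"speech": np.zeros(400), "text": ["h", "i"]}             # no sampling_rate: 400 frames

print(filter(speech_only, speech_length_min=10))  # True: 10 < 5000 < 15000, text ignored
print(filter(text_only))                          # True: 0 < 5 < 200, speech ignored
print(filter(paired, speech_length_min=10))       # True: duration and token count both pass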