mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
update data2vec pretrain: dataset
This commit is contained in:
parent
9befa9e508
commit
55b45487c7
@ -78,6 +78,58 @@ def common_collate_fn(
|
||||
lens = torch.tensor([d[key].shape[0] for d in data], dtype=torch.long)
|
||||
output[key + "_lengths"] = lens
|
||||
|
||||
output = (uttids, output)
|
||||
assert check_return_type(output)
|
||||
return output
|
||||
|
||||
def crop_to_max_size(feature, target_size):
|
||||
size = len(feature)
|
||||
diff = size - target_size
|
||||
if diff <= 0:
|
||||
return feature
|
||||
|
||||
start = np.random.randint(0, diff + 1)
|
||||
end = size - diff + start
|
||||
return feature[start:end]
|
||||
|
||||
|
||||
def clipping_collate_fn(
|
||||
data: Collection[Tuple[str, Dict[str, np.ndarray]]],
|
||||
max_sample_size=None,
|
||||
not_sequence: Collection[str] = (),
|
||||
) -> Tuple[List[str], Dict[str, torch.Tensor]]:
|
||||
# mainly for pre-training
|
||||
assert check_argument_types()
|
||||
uttids = [u for u, _ in data]
|
||||
data = [d for _, d in data]
|
||||
|
||||
assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching"
|
||||
assert all(
|
||||
not k.endswith("_lengths") for k in data[0]
|
||||
), f"*_lengths is reserved: {list(data[0])}"
|
||||
|
||||
output = {}
|
||||
for key in data[0]:
|
||||
array_list = [d[key] for d in data]
|
||||
tensor_list = [torch.from_numpy(a) for a in array_list]
|
||||
sizes = [len(s) for s in tensor_list]
|
||||
if max_sample_size is None:
|
||||
target_size = min(sizes)
|
||||
else:
|
||||
target_size = min(min(sizes), max_sample_size)
|
||||
tensor = tensor_list[0].new_zeros(len(tensor_list), target_size, tensor_list[0].shape[1])
|
||||
for i, (source, size) in enumerate(zip(tensor_list, sizes)):
|
||||
diff = size - target_size
|
||||
if diff == 0:
|
||||
tensor[i] = source
|
||||
else:
|
||||
tensor[i] = crop_to_max_size(source, target_size)
|
||||
output[key] = tensor
|
||||
|
||||
if key not in not_sequence:
|
||||
lens = torch.tensor([source.shape[0] for source in tensor], dtype=torch.long)
|
||||
output[key + "_lengths"] = lens
|
||||
|
||||
output = (uttids, output)
|
||||
assert check_return_type(output)
|
||||
return output
|
||||
@ -102,6 +102,8 @@ class AudioDataset(IterableDataset):
|
||||
elif data_type == "text" or data_type == "sound":
|
||||
text_reader = open(data_file, "r")
|
||||
reader_list.append(text_reader)
|
||||
elif data_type == "none":
|
||||
continue
|
||||
else:
|
||||
raise TypeError("Data type {} is not supported".format(data_type))
|
||||
|
||||
|
||||
@ -6,13 +6,21 @@ def filter(data,
|
||||
speech_length_max=15000,
|
||||
token_length_min=0,
|
||||
token_length_max=200):
|
||||
assert "speech" in data
|
||||
assert "text" in data
|
||||
assert "speech" in data or "text" in data
|
||||
|
||||
if "sampling_rate" in data:
|
||||
speech_length = (data["speech"].shape[0] / data["sampling_rate"]) * 1000.
|
||||
if "speech" in data and "text" in data:
|
||||
if "sampling_rate" in data:
|
||||
speech_length = (data["speech"].shape[0] / data["sampling_rate"]) * 1000.
|
||||
else:
|
||||
speech_length = data["speech"].shape[0]
|
||||
num_tokens = len(data['text'])
|
||||
return speech_length_min < speech_length < speech_length_max and token_length_min < num_tokens < token_length_max
|
||||
elif "speech" in data:
|
||||
if "sampling_rate" in data:
|
||||
speech_length = (data["speech"].shape[0] / data["sampling_rate"]) * 1000.
|
||||
else:
|
||||
speech_length = data["speech"].shape[0]
|
||||
return speech_length_min < speech_length < speech_length_max
|
||||
else:
|
||||
speech_length = data["speech"].shape[0]
|
||||
num_tokens = len(data['text'])
|
||||
|
||||
return speech_length_min < speech_length < speech_length_max and token_length_min < num_tokens < token_length_max
|
||||
num_tokens = len(data['text'])
|
||||
return token_length_min < num_tokens < token_length_max
|
||||
|
||||
Loading…
Reference in New Issue
Block a user