mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
update data2vec pretrain: dataset
commit 55b45487c7 · parent 9befa9e508
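The commit touches the data2vec pre-training data pipeline in three places: a new cropping collate function that batches without padding, a "none" data type that AudioDataset skips, and a filter() that no longer requires samples to carry both speech and text.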
@@ -78,6 +78,58 @@ def common_collate_fn(
             lens = torch.tensor([d[key].shape[0] for d in data], dtype=torch.long)
             output[key + "_lengths"] = lens
+    output = (uttids, output)
+    assert check_return_type(output)
+    return output
+
+
+def crop_to_max_size(feature, target_size):
+    size = len(feature)
+    diff = size - target_size
+    if diff <= 0:
+        return feature
+
+    start = np.random.randint(0, diff + 1)
+    end = size - diff + start
+    return feature[start:end]
+
+
+def clipping_collate_fn(
+    data: Collection[Tuple[str, Dict[str, np.ndarray]]],
+    max_sample_size=None,
+    not_sequence: Collection[str] = (),
+) -> Tuple[List[str], Dict[str, torch.Tensor]]:
+    # mainly for pre-training
+    assert check_argument_types()
+    uttids = [u for u, _ in data]
+    data = [d for _, d in data]
+
+    assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching"
+    assert all(
+        not k.endswith("_lengths") for k in data[0]
+    ), f"*_lengths is reserved: {list(data[0])}"
+
+    output = {}
+    for key in data[0]:
+        array_list = [d[key] for d in data]
+        tensor_list = [torch.from_numpy(a) for a in array_list]
+        sizes = [len(s) for s in tensor_list]
+        if max_sample_size is None:
+            target_size = min(sizes)
+        else:
+            target_size = min(min(sizes), max_sample_size)
+        tensor = tensor_list[0].new_zeros(len(tensor_list), target_size, tensor_list[0].shape[1])
+        for i, (source, size) in enumerate(zip(tensor_list, sizes)):
+            diff = size - target_size
+            if diff == 0:
+                tensor[i] = source
+            else:
+                tensor[i] = crop_to_max_size(source, target_size)
+        output[key] = tensor
+
+        if key not in not_sequence:
+            lens = torch.tensor([source.shape[0] for source in tensor], dtype=torch.long)
+            output[key + "_lengths"] = lens
 
     output = (uttids, output)
     assert check_return_type(output)
     return output
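The hunk adds crop_to_max_size and a clipping_collate_fn that, instead of padding, randomly crops every sample in the batch down to the shortest length present (optionally capped by max_sample_size), so the batch stacks into a dense tensor. Below is a standalone sketch of that behaviour on a fabricated two-utterance batch; crop_to_max_size is restated from the hunk so the snippet runs on its own, and the names and shapes are illustrative, not FunASR API.

import numpy as np
import torch


def crop_to_max_size(feature, target_size):
    # Same logic as the patch: take a random window of target_size frames
    # when the feature is longer, otherwise return it unchanged.
    size = len(feature)
    diff = size - target_size
    if diff <= 0:
        return feature
    start = np.random.randint(0, diff + 1)
    end = size - diff + start  # == start + target_size
    return feature[start:end]


# Fabricated batch: two utterances of 80-dim features with different lengths.
batch = {
    "utt1": np.random.randn(1200, 80).astype(np.float32),
    "utt2": np.random.randn(800, 80).astype(np.float32),
}
# Crop everything to the shortest sample, as clipping_collate_fn does when
# max_sample_size is None, so the batch stacks without any padding.
target_size = min(len(f) for f in batch.values())
stacked = torch.stack(
    [torch.from_numpy(crop_to_max_size(f, target_size)) for f in batch.values()]
)
print(stacked.shape)  # torch.Size([2, 800, 80])

One consequence worth noting: after cropping, every row of the stacked tensor is exactly target_size frames long, so the *_lengths entries clipping_collate_fn emits are uniform by construction.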
@@ -102,6 +102,8 @@ class AudioDataset(IterableDataset):
             elif data_type == "text" or data_type == "sound":
                 text_reader = open(data_file, "r")
                 reader_list.append(text_reader)
+            elif data_type == "none":
+                continue
             else:
                 raise TypeError("Data type {} is not supported".format(data_type))
 
@@ -6,13 +6,21 @@ def filter(data,
            speech_length_max=15000,
            token_length_min=0,
            token_length_max=200):
-    assert "speech" in data
-    assert "text" in data
+    assert "speech" in data or "text" in data
 
-    if "sampling_rate" in data:
-        speech_length = (data["speech"].shape[0] / data["sampling_rate"]) * 1000.
-    else:
-        speech_length = data["speech"].shape[0]
-    num_tokens = len(data['text'])
-
-    return speech_length_min < speech_length < speech_length_max and token_length_min < num_tokens < token_length_max
+    if "speech" in data and "text" in data:
+        if "sampling_rate" in data:
+            speech_length = (data["speech"].shape[0] / data["sampling_rate"]) * 1000.
+        else:
+            speech_length = data["speech"].shape[0]
+        num_tokens = len(data['text'])
+        return speech_length_min < speech_length < speech_length_max and token_length_min < num_tokens < token_length_max
+    elif "speech" in data:
+        if "sampling_rate" in data:
+            speech_length = (data["speech"].shape[0] / data["sampling_rate"]) * 1000.
+        else:
+            speech_length = data["speech"].shape[0]
+        return speech_length_min < speech_length < speech_length_max
+    else:
+        num_tokens = len(data['text'])
+        return token_length_min < num_tokens < token_length_max
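With the asserts relaxed, a sample may now carry speech only (the data2vec pre-training case), text only, or both, and filter() checks just the ranges that apply. Below is a standalone probe of the three branches; the function body is copied from the new version in the hunk so the sketch runs on its own, and the speech_length_min default is an illustrative guess for a parameter line the hunk does not show.

import numpy as np


def filter(data,
           speech_length_min=10,  # illustrative default; this line is elided in the hunk
           speech_length_max=15000,
           token_length_min=0,
           token_length_max=200):
    # Body copied from the patched version above.
    assert "speech" in data or "text" in data

    if "speech" in data and "text" in data:
        if "sampling_rate" in data:
            speech_length = (data["speech"].shape[0] / data["sampling_rate"]) * 1000.
        else:
            speech_length = data["speech"].shape[0]
        num_tokens = len(data['text'])
        return speech_length_min < speech_length < speech_length_max and token_length_min < num_tokens < token_length_max
    elif "speech" in data:
        if "sampling_rate" in data:
            speech_length = (data["speech"].shape[0] / data["sampling_rate"]) * 1000.
        else:
            speech_length = data["speech"].shape[0]
        return speech_length_min < speech_length < speech_length_max
    else:
        num_tokens = len(data['text'])
        return token_length_min < num_tokens < token_length_max


speech_only = {"speech": np.zeros(16000, dtype=np.float32), "sampling_rate": 16000}
print(filter(speech_only))                        # True: 1000 ms, duration range only
print(filter(dict(speech_only, text=[1, 2, 3])))  # True: duration and 3 tokens in range
print(filter({"text": [1, 2, 3]}))                # True: token range only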