FunASR/funasr/datasets/collate_fn.py
2022-11-26 21:56:51 +08:00

83 lines
2.5 KiB
Python

from typing import Collection
from typing import Dict
from typing import List
from typing import Tuple
from typing import Union
import numpy as np
import torch
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr.modules.nets_utils import pad_list
class CommonCollateFn:
    """Functor class of common_collate_fn().

    Holds the padding configuration so that an instance can be called with a
    mini-batch and forward it, together with that configuration, to
    ``common_collate_fn``.

    Args:
        float_pad_value: Forwarded to ``common_collate_fn`` as
            ``float_pad_value``.
        int_pad_value: Forwarded to ``common_collate_fn`` as
            ``int_pad_value``.
        not_sequence: Names forwarded to ``common_collate_fn`` as
            ``not_sequence``; stored as a ``set``.
        max_sample_size: Stored on the instance only. It is not read by any
            method of this class and is not forwarded by ``__call__``.
    """

    def __init__(
        self,
        float_pad_value: Union[float, int] = 0.0,
        int_pad_value: int = -32768,
        not_sequence: Collection[str] = (),
        max_sample_size=None
    ):
        # Runtime validation of the annotated arguments (typeguard).
        assert check_argument_types()
        self.float_pad_value = float_pad_value
        self.int_pad_value = int_pad_value
        # Copied into a set so later membership tests do not depend on the
        # container type the caller passed in.
        self.not_sequence = set(not_sequence)
        self.max_sample_size = max_sample_size

    def __repr__(self):
        """Return a string showing the class and both padding values."""
        # The int_pad_value field must report self.int_pad_value; it used to
        # repeat self.float_pad_value by mistake.
        return (
            f"{self.__class__}(float_pad_value={self.float_pad_value}, "
            f"int_pad_value={self.int_pad_value})"
        )

    def __call__(
        self, data: Collection[Tuple[str, Dict[str, np.ndarray]]]
    ) -> Tuple[List[str], Dict[str, torch.Tensor]]:
        """Collate ``data`` with the stored padding configuration.

        Args:
            data: Collection of ``(utterance_id, {name: ndarray})`` pairs.

        Returns:
            Whatever ``common_collate_fn`` returns for ``data`` and the
            stored ``float_pad_value``, ``int_pad_value`` and
            ``not_sequence``.
        """
        return common_collate_fn(
            data,
            float_pad_value=self.float_pad_value,
            int_pad_value=self.int_pad_value,
            not_sequence=self.not_sequence,
        )
def common_collate_fn(
    data: Collection[Tuple[str, Dict[str, np.ndarray]]],
    float_pad_value: Union[float, int] = 0.0,
    int_pad_value: int = -32768,
    not_sequence: Collection[str] = (),
) -> Tuple[List[str], Dict[str, torch.Tensor]]:
    """Collate per-utterance ndarray dicts into one dict of batch tensors.

    Every sample must expose the same set of keys, and no key may end with
    ``"_lengths"`` because that suffix is used for the generated length
    entries.

    Args:
        data: Collection of ``(utterance_id, {name: ndarray})`` pairs.
        float_pad_value: Pad value used when the first sample's array for a
            key does not have dtype kind ``"i"``.
        int_pad_value: Pad value used when the first sample's array for a
            key has dtype kind ``"i"``.
        not_sequence: Keys for which no ``"<key>_lengths"`` entry is added.

    Returns:
        ``(utterance_ids, batch)`` where ``batch[key]`` is the result of
        ``pad_list`` over that key's arrays converted to tensors and, unless
        ``key`` is in ``not_sequence``, ``batch[key + "_lengths"]`` is a
        ``torch.long`` tensor of each array's ``shape[0]``.
    """
    # Must run before any local rebinding so the original arguments are the
    # ones being validated.
    assert check_argument_types()

    utt_ids = []
    samples = []
    for utt_id, sample in data:
        utt_ids.append(utt_id)
        samples.append(sample)

    # The first sample defines the expected key set and the dtype per key.
    first = samples[0]
    expected_keys = set(first)
    assert all(
        set(sample) == expected_keys for sample in samples
    ), "dict-keys mismatching"
    assert not any(
        name.endswith("_lengths") for name in first
    ), f"*_lengths is reserved: {list(first)}"

    batch = {}
    for name in first:
        is_integer = first[name].dtype.kind == "i"
        fill = int_pad_value if is_integer else float_pad_value

        arrays = [sample[name] for sample in samples]
        batch[name] = pad_list([torch.from_numpy(arr) for arr in arrays], fill)

        if name in not_sequence:
            continue
        # Length along the leading axis of each un-padded array.
        batch[name + "_lengths"] = torch.tensor(
            [arr.shape[0] for arr in arrays], dtype=torch.long
        )

    result = (utt_ids, batch)
    assert check_return_type(result)
    return result