diff --git a/funasr/datasets/iterable_dataset.py b/funasr/datasets/iterable_dataset.py
index 2ac37b2f9..fa0adeb28 100644
--- a/funasr/datasets/iterable_dataset.py
+++ b/funasr/datasets/iterable_dataset.py
@@ -20,7 +20,7 @@ import os.path
 from funasr.datasets.dataset import ESPnetDataset
 
-SUPPORT_AUDIO_TYPE_SETS = ['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma']
+SUPPORT_AUDIO_TYPE_SETS = ['flac', 'mp3', 'ogg', 'opus', 'wav', 'pcm']
 
 
 def load_kaldi(input):
     retval = kaldiio.load_mat(input)
@@ -60,9 +60,14 @@ def load_bytes(input):
     array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
     return array
 
+def load_pcm(input):
+    with open(input, "rb") as f:
+        bytes = f.read()
+    return load_bytes(bytes)
 
 DATA_TYPES = {
     "sound": lambda x: torchaudio.load(x)[0][0].numpy(),
+    "pcm": load_pcm,
     "kaldi_ark": load_kaldi,
     "bytes": load_bytes,
     "waveform": lambda x: x,
@@ -219,6 +224,9 @@ class IterableESPnetDataset(IterableDataset):
                 if audio_type not in SUPPORT_AUDIO_TYPE_SETS:
                     raise NotImplementedError(
                         f'Not supported audio type: {audio_type}')
+                if audio_type == "pcm":
+                    _type = "pcm"
+
             func = DATA_TYPES[_type]
             array = func(value)
             if self.fs is not None and name == "speech":
@@ -318,6 +326,8 @@ class IterableESPnetDataset(IterableDataset):
                 if audio_type not in SUPPORT_AUDIO_TYPE_SETS:
                     raise NotImplementedError(
                         f'Not supported audio type: {audio_type}')
+                if audio_type == "pcm":
+                    _type = "pcm"
             func = DATA_TYPES[_type]
             # Load entry
             array = func(value)
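
Note on the new "pcm" path (illustration only, not part of the patch): for a headerless `.pcm` file, `load_pcm` reads the raw bytes and passes them to the existing `load_bytes` helper, which interprets them as 16-bit integer samples and normalizes them to float32, while the added `audio_type == "pcm"` check reroutes such files away from the default torchaudio-based "sound" loader, which expects a decodable container with a header. The sketch below shows the equivalent decoding in isolation; the `decode_raw_pcm` name, the sample file name, and the 16-bit little-endian assumption are illustrative, not taken from the patch.

```python
import numpy as np

def decode_raw_pcm(path):
    # Hypothetical standalone helper mirroring load_pcm -> load_bytes:
    # read the headerless payload and decode it as 16-bit samples.
    with open(path, "rb") as f:
        raw = f.read()
    samples = np.frombuffer(raw, dtype=np.int16)  # assumes 16-bit native-endian PCM
    # For int16 input, load_bytes' normalization reduces to dividing by 2**15.
    return samples.astype(np.float32) / 32768.0

# Usage (hypothetical file): decode_raw_pcm("utt1.pcm") should produce the same
# float32 waveform as DATA_TYPES["pcm"]("utt1.pcm") does after this patch.
```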