From c14169f374a05387f09087b006d1c046f2720d61 Mon Sep 17 00:00:00 2001 From: hnluo Date: Sun, 5 Feb 2023 12:12:03 +0800 Subject: [PATCH] support audio uppersampling and downsampling --- funasr/datasets/iterable_dataset.py | 31 +++++++++++++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/funasr/datasets/iterable_dataset.py b/funasr/datasets/iterable_dataset.py index 1fc9270c4..2ac37b2f9 100644 --- a/funasr/datasets/iterable_dataset.py +++ b/funasr/datasets/iterable_dataset.py @@ -11,7 +11,6 @@ from typing import Union import kaldiio import numpy as np -import soundfile import torch import torchaudio from torch.utils.data.dataset import IterableDataset @@ -101,6 +100,7 @@ class IterableESPnetDataset(IterableDataset): [str, Dict[str, np.ndarray]], Dict[str, np.ndarray] ] = None, float_dtype: str = "float32", + fs: dict = None, int_dtype: str = "long", key_file: str = None, ): @@ -116,6 +116,7 @@ class IterableESPnetDataset(IterableDataset): self.float_dtype = float_dtype self.int_dtype = int_dtype self.key_file = key_file + self.fs = fs self.debug_info = {} non_iterable_list = [] @@ -175,6 +176,15 @@ class IterableESPnetDataset(IterableDataset): _type = self.path_name_type_list[0][2] func = DATA_TYPES[_type] array = func(value) + if self.fs is not None and name == "speech": + audio_fs = self.fs["audio_fs"] + model_fs = self.fs["model_fs"] + if audio_fs is not None and model_fs is not None: + array = torch.from_numpy(array) + array = array.unsqueeze(0) + array = torchaudio.transforms.Resample(orig_freq=audio_fs, + new_freq=model_fs)(array) + array = array.squeeze(0).numpy() data[name] = array if self.preprocess is not None: @@ -211,6 +221,15 @@ class IterableESPnetDataset(IterableDataset): f'Not supported audio type: {audio_type}') func = DATA_TYPES[_type] array = func(value) + if self.fs is not None and name == "speech": + audio_fs = self.fs["audio_fs"] + model_fs = self.fs["model_fs"] + if audio_fs is not None and model_fs is not None: + array = torch.from_numpy(array) + array = array.unsqueeze(0) + array = torchaudio.transforms.Resample(orig_freq=audio_fs, + new_freq=model_fs)(array) + array = array.squeeze(0).numpy() data[name] = array if self.preprocess is not None: @@ -302,6 +321,15 @@ class IterableESPnetDataset(IterableDataset): func = DATA_TYPES[_type] # Load entry array = func(value) + if self.fs is not None and name == "speech": + audio_fs = self.fs["audio_fs"] + model_fs = self.fs["model_fs"] + if audio_fs is not None and model_fs is not None: + array = torch.from_numpy(array) + array = array.unsqueeze(0) + array = torchaudio.transforms.Resample(orig_freq=audio_fs, + new_freq=model_fs)(array) + array = array.squeeze(0).numpy() data[name] = array if self.non_iterable_dataset is not None: # 2.b. Load data from non-iterable dataset @@ -335,4 +363,3 @@ class IterableESPnetDataset(IterableDataset): if count == 0: raise RuntimeError("No iteration") -