From 60c8f036e0fd3e29a0334c577f5f7d91f8b01982 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BB=81=E8=BF=B7?= Date: Fri, 17 Mar 2023 20:09:02 +0800 Subject: [PATCH] update audio type check --- funasr/datasets/iterable_dataset.py | 18 ++++------ funasr/utils/asr_utils.py | 52 +++++++++++++++-------------- 2 files changed, 33 insertions(+), 37 deletions(-) diff --git a/funasr/datasets/iterable_dataset.py b/funasr/datasets/iterable_dataset.py index c8c51d458..4b2fb1a61 100644 --- a/funasr/datasets/iterable_dataset.py +++ b/funasr/datasets/iterable_dataset.py @@ -228,13 +228,9 @@ class IterableESPnetDataset(IterableDataset): name = self.path_name_type_list[i][1] _type = self.path_name_type_list[i][2] if _type == "sound": - audio_type = os.path.basename(value).split(".")[-1].lower() - if audio_type not in SUPPORT_AUDIO_TYPE_SETS: - raise NotImplementedError( - f'Not supported audio type: {audio_type}') - if audio_type == "pcm": - _type = "pcm" - + audio_type = os.path.basename(value).lower() + if audio_type.rfind(".pcm") >= 0: + _type = "pcm" func = DATA_TYPES[_type] array = func(value) if self.fs is not None and (name == "speech" or name == "ref_speech"): @@ -336,11 +332,8 @@ class IterableESPnetDataset(IterableDataset): # 2.a. Load data streamingly for value, (path, name, _type) in zip(values, self.path_name_type_list): if _type == "sound": - audio_type = os.path.basename(value).split(".")[-1].lower() - if audio_type not in SUPPORT_AUDIO_TYPE_SETS: - raise NotImplementedError( - f'Not supported audio type: {audio_type}') - if audio_type == "pcm": + audio_type = os.path.basename(value).lower() + if audio_type.rfind(".pcm") >= 0: _type = "pcm" func = DATA_TYPES[_type] # Load entry @@ -392,3 +385,4 @@ class IterableESPnetDataset(IterableDataset): if count == 0: raise RuntimeError("No iteration") + diff --git a/funasr/utils/asr_utils.py b/funasr/utils/asr_utils.py index 0f0e4c356..4067b048b 100644 --- a/funasr/utils/asr_utils.py +++ b/funasr/utils/asr_utils.py @@ -58,14 +58,15 @@ def type_checking(audio_in: Union[str, bytes], if r_recog_type is None and audio_in is not None: # audio_in is wav, recog_type is wav_file if os.path.isfile(audio_in): - audio_type = os.path.basename(audio_in).split(".")[-1].lower() - if audio_type in SUPPORT_AUDIO_TYPE_SETS: - r_recog_type = 'wav' - r_audio_format = 'wav' - elif audio_type == "scp": + audio_type = os.path.basename(audio_in).lower() + for support_audio_type in SUPPORT_AUDIO_TYPE_SETS: + if audio_type.rfind(".{}".format(support_audio_type)) >= 0: + r_recog_type = 'wav' + r_audio_format = 'wav' + if audio_type.rfind(".scp") >= 0: r_recog_type = 'wav' r_audio_format = 'scp' - else: + if r_recog_type is None: raise NotImplementedError( f'Not supported audio type: {audio_type}') @@ -128,13 +129,15 @@ def get_sr_from_bytes(wav: bytes): def get_sr_from_wav(fname: str): fs = None if os.path.isfile(fname): - audio_type = os.path.basename(fname).split(".")[-1].lower() - if audio_type in SUPPORT_AUDIO_TYPE_SETS: - if audio_type == "pcm": - fs = None - else: - audio, fs = torchaudio.load(fname) - elif audio_type == "scp": + audio_type = os.path.basename(fname).lower() + for support_audio_type in SUPPORT_AUDIO_TYPE_SETS: + if audio_type.rfind(".{}".format(support_audio_type)) >= 0: + if support_audio_type == "pcm": + fs = None + else: + audio, fs = torchaudio.load(fname) + break + if audio_type.rfind(".scp") >= 0: with open(fname, encoding="utf-8") as f: for line in f: wav_path = line.split()[1] @@ -147,9 +150,7 @@ def get_sr_from_wav(fname: str): for file in dir_files: file_path = os.path.join(fname, file) if os.path.isfile(file_path): - audio_type = os.path.basename(file_path).split(".")[-1].lower() - if audio_type in SUPPORT_AUDIO_TYPE_SETS: - fs = get_sr_from_wav(file_path) + fs = get_sr_from_wav(file_path) elif os.path.isdir(file_path): fs = get_sr_from_wav(file_path) @@ -165,12 +166,12 @@ def find_file_by_ends(dir_path: str, ends: str): file_path = os.path.join(dir_path, file) if os.path.isfile(file_path): if ends == ".wav" or ends == ".WAV": - audio_type = os.path.basename(file_path).split(".")[-1].lower() - if audio_type in SUPPORT_AUDIO_TYPE_SETS: - return True - else: - raise NotImplementedError( - f'Not supported audio type: {audio_type}') + audio_type = os.path.basename(file_path).lower() + for support_audio_type in SUPPORT_AUDIO_TYPE_SETS: + if audio_type.rfind(".{}".format(support_audio_type)) >= 0: + return True + raise NotImplementedError( + f'Not supported audio type: {audio_type}') elif file_path.endswith(ends): return True elif os.path.isdir(file_path): @@ -185,9 +186,10 @@ def recursion_dir_all_wav(wav_list, dir_path: str) -> List[str]: for file in dir_files: file_path = os.path.join(dir_path, file) if os.path.isfile(file_path): - audio_type = os.path.basename(file_path).split(".")[-1].lower() - if audio_type in SUPPORT_AUDIO_TYPE_SETS: - wav_list.append(file_path) + audio_type = os.path.basename(file_path).lower() + for support_audio_type in SUPPORT_AUDIO_TYPE_SETS: + if audio_type.rfind(".{}".format(support_audio_type)) >= 0: + wav_list.append(file_path) elif os.path.isdir(file_path): recursion_dir_all_wav(wav_list, file_path)