diff --git a/.gitignore b/.gitignore index dea463409..4023869b5 100644 --- a/.gitignore +++ b/.gitignore @@ -22,3 +22,4 @@ modelscope samples .ipynb_checkpoints outputs* +emotion2vec* diff --git a/funasr/download/file.py b/funasr/download/file.py new file mode 100644 index 000000000..d93f24c96 --- /dev/null +++ b/funasr/download/file.py @@ -0,0 +1,328 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import contextlib +import os +import tempfile +from abc import ABCMeta, abstractmethod +from pathlib import Path +from typing import Generator, Union + +import requests + + +class Storage(metaclass=ABCMeta): + """Abstract class of storage. + + All backends need to implement two apis: ``read()`` and ``read_text()``. + ``read()`` reads the file as a byte stream and ``read_text()`` reads + the file as texts. + """ + + @abstractmethod + def read(self, filepath: str): + pass + + @abstractmethod + def read_text(self, filepath: str): + pass + + @abstractmethod + def write(self, obj: bytes, filepath: Union[str, Path]) -> None: + pass + + @abstractmethod + def write_text(self, + obj: str, + filepath: Union[str, Path], + encoding: str = 'utf-8') -> None: + pass + + +class LocalStorage(Storage): + """Local hard disk storage""" + + def read(self, filepath: Union[str, Path]) -> bytes: + """Read data from a given ``filepath`` with 'rb' mode. + + Args: + filepath (str or Path): Path to read data. + + Returns: + bytes: Expected bytes object. + """ + with open(filepath, 'rb') as f: + content = f.read() + return content + + def read_text(self, + filepath: Union[str, Path], + encoding: str = 'utf-8') -> str: + """Read data from a given ``filepath`` with 'r' mode. + + Args: + filepath (str or Path): Path to read data. + encoding (str): The encoding format used to open the ``filepath``. + Default: 'utf-8'. + + Returns: + str: Expected text reading from ``filepath``. 
+ """ + with open(filepath, 'r', encoding=encoding) as f: + value_buf = f.read() + return value_buf + + def write(self, obj: bytes, filepath: Union[str, Path]) -> None: + """Write data to a given ``filepath`` with 'wb' mode. + + Note: + ``write`` will create a directory if the directory of ``filepath`` + does not exist. + + Args: + obj (bytes): Data to be written. + filepath (str or Path): Path to write data. + """ + dirname = os.path.dirname(filepath) + if dirname and not os.path.exists(dirname): + os.makedirs(dirname, exist_ok=True) + + with open(filepath, 'wb') as f: + f.write(obj) + + def write_text(self, + obj: str, + filepath: Union[str, Path], + encoding: str = 'utf-8') -> None: + """Write data to a given ``filepath`` with 'w' mode. + + Note: + ``write_text`` will create a directory if the directory of + ``filepath`` does not exist. + + Args: + obj (str): Data to be written. + filepath (str or Path): Path to write data. + encoding (str): The encoding format used to open the ``filepath``. + Default: 'utf-8'. + """ + dirname = os.path.dirname(filepath) + if dirname and not os.path.exists(dirname): + os.makedirs(dirname, exist_ok=True) + + with open(filepath, 'w', encoding=encoding) as f: + f.write(obj) + + @contextlib.contextmanager + def as_local_path( + self, + filepath: Union[str, + Path]) -> Generator[Union[str, Path], None, None]: + """Only for unified API and do nothing.""" + yield filepath + + +class HTTPStorage(Storage): + """HTTP and HTTPS storage.""" + + def read(self, url): + # TODO @wenmeng.zwm add progress bar if file is too large + r = requests.get(url) + r.raise_for_status() + return r.content + + def read_text(self, url): + r = requests.get(url) + r.raise_for_status() + return r.text + + @contextlib.contextmanager + def as_local_path( + self, filepath: str) -> Generator[Union[str, Path], None, None]: + """Download a file from ``filepath``. + + ``as_local_path`` is decorated by :meth:`contextlib.contextmanager`. 
It + can be called with ``with`` statement, and when exists from the + ``with`` statement, the temporary path will be released. + + Args: + filepath (str): Download a file from ``filepath``. + + Examples: + >>> storage = HTTPStorage() + >>> # After existing from the ``with`` clause, + >>> # the path will be removed + >>> with storage.get_local_path('http://path/to/file') as path: + ... # do something here + """ + try: + f = tempfile.NamedTemporaryFile(delete=False) + f.write(self.read(filepath)) + f.close() + yield f.name + finally: + os.remove(f.name) + + def write(self, obj: bytes, url: Union[str, Path]) -> None: + raise NotImplementedError('write is not supported by HTTP Storage') + + def write_text(self, + obj: str, + url: Union[str, Path], + encoding: str = 'utf-8') -> None: + raise NotImplementedError( + 'write_text is not supported by HTTP Storage') + + +class OSSStorage(Storage): + """OSS storage.""" + + def __init__(self, oss_config_file=None): + # read from config file or env var + raise NotImplementedError( + 'OSSStorage.__init__ to be implemented in the future') + + def read(self, filepath): + raise NotImplementedError( + 'OSSStorage.read to be implemented in the future') + + def read_text(self, filepath, encoding='utf-8'): + raise NotImplementedError( + 'OSSStorage.read_text to be implemented in the future') + + @contextlib.contextmanager + def as_local_path( + self, filepath: str) -> Generator[Union[str, Path], None, None]: + """Download a file from ``filepath``. + + ``as_local_path`` is decorated by :meth:`contextlib.contextmanager`. It + can be called with ``with`` statement, and when exists from the + ``with`` statement, the temporary path will be released. + + Args: + filepath (str): Download a file from ``filepath``. + + Examples: + >>> storage = OSSStorage() + >>> # After existing from the ``with`` clause, + >>> # the path will be removed + >>> with storage.get_local_path('http://path/to/file') as path: + ... 
# do something here + """ + try: + f = tempfile.NamedTemporaryFile(delete=False) + f.write(self.read(filepath)) + f.close() + yield f.name + finally: + os.remove(f.name) + + def write(self, obj: bytes, filepath: Union[str, Path]) -> None: + raise NotImplementedError( + 'OSSStorage.write to be implemented in the future') + + def write_text(self, + obj: str, + filepath: Union[str, Path], + encoding: str = 'utf-8') -> None: + raise NotImplementedError( + 'OSSStorage.write_text to be implemented in the future') + + +G_STORAGES = {} + + +class File(object): + _prefix_to_storage: dict = { + 'oss': OSSStorage, + 'http': HTTPStorage, + 'https': HTTPStorage, + 'local': LocalStorage, + } + + @staticmethod + def _get_storage(uri): + assert isinstance(uri, + str), f'uri should be str type, but got {type(uri)}' + + if '://' not in uri: + # local path + storage_type = 'local' + else: + prefix, _ = uri.split('://') + storage_type = prefix + + assert storage_type in File._prefix_to_storage, \ + f'Unsupported uri {uri}, valid prefixs: '\ + f'{list(File._prefix_to_storage.keys())}' + + if storage_type not in G_STORAGES: + G_STORAGES[storage_type] = File._prefix_to_storage[storage_type]() + + return G_STORAGES[storage_type] + + @staticmethod + def read(uri: str) -> bytes: + """Read data from a given ``filepath`` with 'rb' mode. + + Args: + filepath (str or Path): Path to read data. + + Returns: + bytes: Expected bytes object. + """ + storage = File._get_storage(uri) + return storage.read(uri) + + @staticmethod + def read_text(uri: Union[str, Path], encoding: str = 'utf-8') -> str: + """Read data from a given ``filepath`` with 'r' mode. + + Args: + filepath (str or Path): Path to read data. + encoding (str): The encoding format used to open the ``filepath``. + Default: 'utf-8'. + + Returns: + str: Expected text reading from ``filepath``. 
+ """ + storage = File._get_storage(uri) + return storage.read_text(uri) + + @staticmethod + def write(obj: bytes, uri: Union[str, Path]) -> None: + """Write data to a given ``filepath`` with 'wb' mode. + + Note: + ``write`` will create a directory if the directory of ``filepath`` + does not exist. + + Args: + obj (bytes): Data to be written. + filepath (str or Path): Path to write data. + """ + storage = File._get_storage(uri) + return storage.write(obj, uri) + + @staticmethod + def write_text(obj: str, uri: str, encoding: str = 'utf-8') -> None: + """Write data to a given ``filepath`` with 'w' mode. + + Note: + ``write_text`` will create a directory if the directory of + ``filepath`` does not exist. + + Args: + obj (str): Data to be written. + filepath (str or Path): Path to write data. + encoding (str): The encoding format used to open the ``filepath``. + Default: 'utf-8'. + """ + storage = File._get_storage(uri) + return storage.write_text(obj, uri) + + @contextlib.contextmanager + def as_local_path(uri: str) -> Generator[Union[str, Path], None, None]: + """Only for unified API and do nothing.""" + storage = File._get_storage(uri) + with storage.as_local_path(uri) as local_path: + yield local_path diff --git a/funasr/utils/asr_utils.py b/funasr/utils/asr_utils.py deleted file mode 100644 index 364746ad0..000000000 --- a/funasr/utils/asr_utils.py +++ /dev/null @@ -1,359 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. 
- -import os -import struct -from typing import Any, Dict, List, Union - -import torchaudio -import librosa -import numpy as np -import pkg_resources -from modelscope.utils.logger import get_logger - -logger = get_logger() - -green_color = '\033[1;32m' -red_color = '\033[0;31;40m' -yellow_color = '\033[0;33;40m' -end_color = '\033[0m' - -global_asr_language = 'zh-cn' - -SUPPORT_AUDIO_TYPE_SETS = ['flac', 'mp3', 'ogg', 'opus', 'wav', 'pcm'] - -def get_version(): - return float(pkg_resources.get_distribution('easyasr').version) - - -def sample_rate_checking(audio_in: Union[str, bytes], audio_format: str): - r_audio_fs = None - - if audio_format == 'wav' or audio_format == 'scp': - r_audio_fs = get_sr_from_wav(audio_in) - elif audio_format == 'pcm' and isinstance(audio_in, bytes): - r_audio_fs = get_sr_from_bytes(audio_in) - - return r_audio_fs - - -def type_checking(audio_in: Union[str, bytes], - audio_fs: int = None, - recog_type: str = None, - audio_format: str = None): - r_recog_type = recog_type - r_audio_format = audio_format - r_wav_path = audio_in - - if isinstance(audio_in, str): - assert os.path.exists(audio_in), f'wav_path:{audio_in} does not exist' - elif isinstance(audio_in, bytes): - assert len(audio_in) > 0, 'audio in is empty' - r_audio_format = 'pcm' - r_recog_type = 'wav' - - if audio_in is None: - # for raw_inputs - r_recog_type = 'wav' - r_audio_format = 'pcm' - - if r_recog_type is None and audio_in is not None: - # audio_in is wav, recog_type is wav_file - if os.path.isfile(audio_in): - audio_type = os.path.basename(audio_in).lower() - for support_audio_type in SUPPORT_AUDIO_TYPE_SETS: - if audio_type.rfind(".{}".format(support_audio_type)) >= 0: - r_recog_type = 'wav' - r_audio_format = 'wav' - if audio_type.rfind(".scp") >= 0: - r_recog_type = 'wav' - r_audio_format = 'scp' - if r_recog_type is None: - raise NotImplementedError( - f'Not supported audio type: {audio_type}') - - # recog_type is datasets_file - elif os.path.isdir(audio_in): - 
dir_name = os.path.basename(audio_in) - if 'test' in dir_name: - r_recog_type = 'test' - elif 'dev' in dir_name: - r_recog_type = 'dev' - elif 'train' in dir_name: - r_recog_type = 'train' - - if r_audio_format is None: - if find_file_by_ends(audio_in, '.ark'): - r_audio_format = 'kaldi_ark' - elif find_file_by_ends(audio_in, '.wav') or find_file_by_ends( - audio_in, '.WAV'): - r_audio_format = 'wav' - elif find_file_by_ends(audio_in, '.records'): - r_audio_format = 'tfrecord' - - if r_audio_format == 'kaldi_ark' and r_recog_type != 'wav': - # datasets with kaldi_ark file - r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../')) - elif r_audio_format == 'tfrecord' and r_recog_type != 'wav': - # datasets with tensorflow records file - r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../')) - elif r_audio_format == 'wav' and r_recog_type != 'wav': - # datasets with waveform files - r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../../')) - - return r_recog_type, r_audio_format, r_wav_path - - -def get_sr_from_bytes(wav: bytes): - sr = None - data = wav - if len(data) > 44: - try: - header_fields = {} - header_fields['ChunkID'] = str(data[0:4], 'UTF-8') - header_fields['Format'] = str(data[8:12], 'UTF-8') - header_fields['Subchunk1ID'] = str(data[12:16], 'UTF-8') - if header_fields['ChunkID'] == 'RIFF' and header_fields[ - 'Format'] == 'WAVE' and header_fields[ - 'Subchunk1ID'] == 'fmt ': - header_fields['SampleRate'] = struct.unpack('= 0: - if support_audio_type == "pcm": - fs = None - else: - try: - audio, fs = torchaudio.load(fname) - except: - audio, fs = librosa.load(fname) - break - if audio_type.rfind(".scp") >= 0: - with open(fname, encoding="utf-8") as f: - for line in f: - wav_path = line.split()[1] - fs = get_sr_from_wav(wav_path) - if fs is not None: - break - return fs - elif os.path.isdir(fname): - dir_files = os.listdir(fname) - for file in dir_files: - file_path = os.path.join(fname, file) - if os.path.isfile(file_path): - fs = 
get_sr_from_wav(file_path) - elif os.path.isdir(file_path): - fs = get_sr_from_wav(file_path) - - if fs is not None: - break - - return fs - - -def find_file_by_ends(dir_path: str, ends: str): - dir_files = os.listdir(dir_path) - for file in dir_files: - file_path = os.path.join(dir_path, file) - if os.path.isfile(file_path): - if ends == ".wav" or ends == ".WAV": - audio_type = os.path.basename(file_path).lower() - for support_audio_type in SUPPORT_AUDIO_TYPE_SETS: - if audio_type.rfind(".{}".format(support_audio_type)) >= 0: - return True - raise NotImplementedError( - f'Not supported audio type: {audio_type}') - elif file_path.endswith(ends): - return True - elif os.path.isdir(file_path): - if find_file_by_ends(file_path, ends): - return True - - return False - - -def recursion_dir_all_wav(wav_list, dir_path: str) -> List[str]: - dir_files = os.listdir(dir_path) - for file in dir_files: - file_path = os.path.join(dir_path, file) - if os.path.isfile(file_path): - audio_type = os.path.basename(file_path).lower() - for support_audio_type in SUPPORT_AUDIO_TYPE_SETS: - if audio_type.rfind(".{}".format(support_audio_type)) >= 0: - wav_list.append(file_path) - elif os.path.isdir(file_path): - recursion_dir_all_wav(wav_list, file_path) - - return wav_list - -def compute_wer(hyp_list: List[Any], - ref_list: List[Any], - lang: str = None) -> Dict[str, Any]: - assert len(hyp_list) > 0, 'hyp list is empty' - assert len(ref_list) > 0, 'ref list is empty' - - rst = { - 'Wrd': 0, - 'Corr': 0, - 'Ins': 0, - 'Del': 0, - 'Sub': 0, - 'Snt': 0, - 'Err': 0.0, - 'S.Err': 0.0, - 'wrong_words': 0, - 'wrong_sentences': 0 - } - - if lang is None: - lang = global_asr_language - - for h_item in hyp_list: - for r_item in ref_list: - if h_item['key'] == r_item['key']: - out_item = compute_wer_by_line(h_item['value'], - r_item['value'], - lang) - rst['Wrd'] += out_item['nwords'] - rst['Corr'] += out_item['cor'] - rst['wrong_words'] += out_item['wrong'] - rst['Ins'] += out_item['ins'] - 
rst['Del'] += out_item['del'] - rst['Sub'] += out_item['sub'] - rst['Snt'] += 1 - if out_item['wrong'] > 0: - rst['wrong_sentences'] += 1 - print_wrong_sentence(key=h_item['key'], - hyp=h_item['value'], - ref=r_item['value']) - else: - print_correct_sentence(key=h_item['key'], - hyp=h_item['value'], - ref=r_item['value']) - - break - - if rst['Wrd'] > 0: - rst['Err'] = round(rst['wrong_words'] * 100 / rst['Wrd'], 2) - if rst['Snt'] > 0: - rst['S.Err'] = round(rst['wrong_sentences'] * 100 / rst['Snt'], 2) - - return rst - - -def compute_wer_by_line(hyp: List[str], - ref: List[str], - lang: str = 'zh-cn') -> Dict[str, Any]: - if lang != 'zh-cn': - hyp = hyp.split() - ref = ref.split() - - hyp = list(map(lambda x: x.lower(), hyp)) - ref = list(map(lambda x: x.lower(), ref)) - - len_hyp = len(hyp) - len_ref = len(ref) - - cost_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int16) - - ops_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int8) - - for i in range(len_hyp + 1): - cost_matrix[i][0] = i - for j in range(len_ref + 1): - cost_matrix[0][j] = j - - for i in range(1, len_hyp + 1): - for j in range(1, len_ref + 1): - if hyp[i - 1] == ref[j - 1]: - cost_matrix[i][j] = cost_matrix[i - 1][j - 1] - else: - substitution = cost_matrix[i - 1][j - 1] + 1 - insertion = cost_matrix[i - 1][j] + 1 - deletion = cost_matrix[i][j - 1] + 1 - - compare_val = [substitution, insertion, deletion] - - min_val = min(compare_val) - operation_idx = compare_val.index(min_val) + 1 - cost_matrix[i][j] = min_val - ops_matrix[i][j] = operation_idx - - match_idx = [] - i = len_hyp - j = len_ref - rst = { - 'nwords': len_ref, - 'cor': 0, - 'wrong': 0, - 'ins': 0, - 'del': 0, - 'sub': 0 - } - while i >= 0 or j >= 0: - i_idx = max(0, i) - j_idx = max(0, j) - - if ops_matrix[i_idx][j_idx] == 0: # correct - if i - 1 >= 0 and j - 1 >= 0: - match_idx.append((j - 1, i - 1)) - rst['cor'] += 1 - - i -= 1 - j -= 1 - - elif ops_matrix[i_idx][j_idx] == 2: # insert - i -= 1 - rst['ins'] += 1 - 
- elif ops_matrix[i_idx][j_idx] == 3: # delete - j -= 1 - rst['del'] += 1 - - elif ops_matrix[i_idx][j_idx] == 1: # substitute - i -= 1 - j -= 1 - rst['sub'] += 1 - - if i < 0 and j >= 0: - rst['del'] += 1 - elif j < 0 and i >= 0: - rst['ins'] += 1 - - match_idx.reverse() - wrong_cnt = cost_matrix[len_hyp][len_ref] - rst['wrong'] = wrong_cnt - - return rst - - -def print_wrong_sentence(key: str, hyp: str, ref: str): - space = len(key) - print(key + yellow_color + ' ref: ' + ref) - print(' ' * space + red_color + ' hyp: ' + hyp + end_color) - - -def print_correct_sentence(key: str, hyp: str, ref: str): - space = len(key) - print(key + yellow_color + ' ref: ' + ref) - print(' ' * space + green_color + ' hyp: ' + hyp + end_color) - - -def print_progress(percent): - if percent > 1: - percent = 1 - res = int(50 * percent) * '#' - print('\r[%-50s] %d%%' % (res, int(100 * percent)), end='') diff --git a/funasr/utils/load_utils.py b/funasr/utils/load_utils.py index 4fb27c074..c5c3ffcf9 100644 --- a/funasr/utils/load_utils.py +++ b/funasr/utils/load_utils.py @@ -9,7 +9,12 @@ import torchaudio import time import logging from torch.nn.utils.rnn import pad_sequence - +try: + from urllib.parse import urlparse + from funasr.download.file import HTTPStorage + import tempfile +except: + print("urllib is not installed, if you infer from url, please install it first.") # def load_audio(data_or_path_or_list, fs: int=16000, audio_fs: int=16000): # # if isinstance(data_or_path_or_list, (list, tuple)): @@ -43,7 +48,8 @@ def load_audio_text_image_video(data_or_path_or_list, fs: int = 16000, audio_fs: return data_or_path_or_list_ret else: return [load_audio_text_image_video(audio, fs=fs, audio_fs=audio_fs) for audio in data_or_path_or_list] - + if isinstance(data_or_path_or_list, str) and data_or_path_or_list.startswith('http'): + data_or_path_or_list = download_from_url(data_or_path_or_list) if isinstance(data_or_path_or_list, str) and os.path.exists(data_or_path_or_list): 
         data_or_path_or_list, audio_fs = torchaudio.load(data_or_path_or_list)
         data_or_path_or_list = data_or_path_or_list[0, :]
@@ -99,4 +105,21 @@ def extract_fbank(data, data_len = None, data_type: str="sound", frontend=None):
     if isinstance(data_len, (list, tuple)):
         data_len = torch.tensor([data_len])
-    return data.to(torch.float32), data_len.to(torch.int32)
\ No newline at end of file
+    return data.to(torch.float32), data_len.to(torch.int32)
+
+def download_from_url(url):
+
+    result = urlparse(url)
+    file_path = None
+    if result.scheme is not None and len(result.scheme) > 0:
+        storage = HTTPStorage()
+        # bytes
+        data = storage.read(url)
+        work_dir = tempfile.mkdtemp()
+        if not os.path.exists(work_dir):
+            os.makedirs(work_dir)
+        file_path = os.path.join(work_dir, os.path.basename(url))
+        with open(file_path, 'wb') as fb:
+            fb.write(data)
+    assert file_path is not None, f"failed to download: {url}"
+    return file_path
\ No newline at end of file