funasr1.0 infer url modelscope

This commit is contained in:
游雁 2024-01-09 00:13:51 +08:00
parent 0a53be28e2
commit f14f9f8d15
4 changed files with 355 additions and 362 deletions

1
.gitignore vendored
View File

@ -22,3 +22,4 @@ modelscope
samples
.ipynb_checkpoints
outputs*
emotion2vec*

328
funasr/download/file.py Normal file
View File

@ -0,0 +1,328 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import contextlib
import os
import tempfile
from abc import ABCMeta, abstractmethod
from pathlib import Path
from typing import Generator, Union
import requests
class Storage(metaclass=ABCMeta):
"""Abstract class of storage.
All backends need to implement two apis: ``read()`` and ``read_text()``.
``read()`` reads the file as a byte stream and ``read_text()`` reads
the file as texts.
"""
@abstractmethod
def read(self, filepath: str):
pass
@abstractmethod
def read_text(self, filepath: str):
pass
@abstractmethod
def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
pass
@abstractmethod
def write_text(self,
obj: str,
filepath: Union[str, Path],
encoding: str = 'utf-8') -> None:
pass
class LocalStorage(Storage):
"""Local hard disk storage"""
def read(self, filepath: Union[str, Path]) -> bytes:
"""Read data from a given ``filepath`` with 'rb' mode.
Args:
filepath (str or Path): Path to read data.
Returns:
bytes: Expected bytes object.
"""
with open(filepath, 'rb') as f:
content = f.read()
return content
def read_text(self,
filepath: Union[str, Path],
encoding: str = 'utf-8') -> str:
"""Read data from a given ``filepath`` with 'r' mode.
Args:
filepath (str or Path): Path to read data.
encoding (str): The encoding format used to open the ``filepath``.
Default: 'utf-8'.
Returns:
str: Expected text reading from ``filepath``.
"""
with open(filepath, 'r', encoding=encoding) as f:
value_buf = f.read()
return value_buf
def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
"""Write data to a given ``filepath`` with 'wb' mode.
Note:
``write`` will create a directory if the directory of ``filepath``
does not exist.
Args:
obj (bytes): Data to be written.
filepath (str or Path): Path to write data.
"""
dirname = os.path.dirname(filepath)
if dirname and not os.path.exists(dirname):
os.makedirs(dirname, exist_ok=True)
with open(filepath, 'wb') as f:
f.write(obj)
def write_text(self,
obj: str,
filepath: Union[str, Path],
encoding: str = 'utf-8') -> None:
"""Write data to a given ``filepath`` with 'w' mode.
Note:
``write_text`` will create a directory if the directory of
``filepath`` does not exist.
Args:
obj (str): Data to be written.
filepath (str or Path): Path to write data.
encoding (str): The encoding format used to open the ``filepath``.
Default: 'utf-8'.
"""
dirname = os.path.dirname(filepath)
if dirname and not os.path.exists(dirname):
os.makedirs(dirname, exist_ok=True)
with open(filepath, 'w', encoding=encoding) as f:
f.write(obj)
@contextlib.contextmanager
def as_local_path(
self,
filepath: Union[str,
Path]) -> Generator[Union[str, Path], None, None]:
"""Only for unified API and do nothing."""
yield filepath
class HTTPStorage(Storage):
"""HTTP and HTTPS storage."""
def read(self, url):
# TODO @wenmeng.zwm add progress bar if file is too large
r = requests.get(url)
r.raise_for_status()
return r.content
def read_text(self, url):
r = requests.get(url)
r.raise_for_status()
return r.text
@contextlib.contextmanager
def as_local_path(
self, filepath: str) -> Generator[Union[str, Path], None, None]:
"""Download a file from ``filepath``.
``as_local_path`` is decorated by :meth:`contextlib.contextmanager`. It
can be called with ``with`` statement, and when exists from the
``with`` statement, the temporary path will be released.
Args:
filepath (str): Download a file from ``filepath``.
Examples:
>>> storage = HTTPStorage()
>>> # After existing from the ``with`` clause,
>>> # the path will be removed
>>> with storage.get_local_path('http://path/to/file') as path:
... # do something here
"""
try:
f = tempfile.NamedTemporaryFile(delete=False)
f.write(self.read(filepath))
f.close()
yield f.name
finally:
os.remove(f.name)
def write(self, obj: bytes, url: Union[str, Path]) -> None:
raise NotImplementedError('write is not supported by HTTP Storage')
def write_text(self,
obj: str,
url: Union[str, Path],
encoding: str = 'utf-8') -> None:
raise NotImplementedError(
'write_text is not supported by HTTP Storage')
class OSSStorage(Storage):
"""OSS storage."""
def __init__(self, oss_config_file=None):
# read from config file or env var
raise NotImplementedError(
'OSSStorage.__init__ to be implemented in the future')
def read(self, filepath):
raise NotImplementedError(
'OSSStorage.read to be implemented in the future')
def read_text(self, filepath, encoding='utf-8'):
raise NotImplementedError(
'OSSStorage.read_text to be implemented in the future')
@contextlib.contextmanager
def as_local_path(
self, filepath: str) -> Generator[Union[str, Path], None, None]:
"""Download a file from ``filepath``.
``as_local_path`` is decorated by :meth:`contextlib.contextmanager`. It
can be called with ``with`` statement, and when exists from the
``with`` statement, the temporary path will be released.
Args:
filepath (str): Download a file from ``filepath``.
Examples:
>>> storage = OSSStorage()
>>> # After existing from the ``with`` clause,
>>> # the path will be removed
>>> with storage.get_local_path('http://path/to/file') as path:
... # do something here
"""
try:
f = tempfile.NamedTemporaryFile(delete=False)
f.write(self.read(filepath))
f.close()
yield f.name
finally:
os.remove(f.name)
def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
raise NotImplementedError(
'OSSStorage.write to be implemented in the future')
def write_text(self,
obj: str,
filepath: Union[str, Path],
encoding: str = 'utf-8') -> None:
raise NotImplementedError(
'OSSStorage.write_text to be implemented in the future')
G_STORAGES = {}
class File(object):
_prefix_to_storage: dict = {
'oss': OSSStorage,
'http': HTTPStorage,
'https': HTTPStorage,
'local': LocalStorage,
}
@staticmethod
def _get_storage(uri):
assert isinstance(uri,
str), f'uri should be str type, but got {type(uri)}'
if '://' not in uri:
# local path
storage_type = 'local'
else:
prefix, _ = uri.split('://')
storage_type = prefix
assert storage_type in File._prefix_to_storage, \
f'Unsupported uri {uri}, valid prefixs: '\
f'{list(File._prefix_to_storage.keys())}'
if storage_type not in G_STORAGES:
G_STORAGES[storage_type] = File._prefix_to_storage[storage_type]()
return G_STORAGES[storage_type]
@staticmethod
def read(uri: str) -> bytes:
"""Read data from a given ``filepath`` with 'rb' mode.
Args:
filepath (str or Path): Path to read data.
Returns:
bytes: Expected bytes object.
"""
storage = File._get_storage(uri)
return storage.read(uri)
@staticmethod
def read_text(uri: Union[str, Path], encoding: str = 'utf-8') -> str:
"""Read data from a given ``filepath`` with 'r' mode.
Args:
filepath (str or Path): Path to read data.
encoding (str): The encoding format used to open the ``filepath``.
Default: 'utf-8'.
Returns:
str: Expected text reading from ``filepath``.
"""
storage = File._get_storage(uri)
return storage.read_text(uri)
@staticmethod
def write(obj: bytes, uri: Union[str, Path]) -> None:
"""Write data to a given ``filepath`` with 'wb' mode.
Note:
``write`` will create a directory if the directory of ``filepath``
does not exist.
Args:
obj (bytes): Data to be written.
filepath (str or Path): Path to write data.
"""
storage = File._get_storage(uri)
return storage.write(obj, uri)
@staticmethod
def write_text(obj: str, uri: str, encoding: str = 'utf-8') -> None:
"""Write data to a given ``filepath`` with 'w' mode.
Note:
``write_text`` will create a directory if the directory of
``filepath`` does not exist.
Args:
obj (str): Data to be written.
filepath (str or Path): Path to write data.
encoding (str): The encoding format used to open the ``filepath``.
Default: 'utf-8'.
"""
storage = File._get_storage(uri)
return storage.write_text(obj, uri)
@contextlib.contextmanager
def as_local_path(uri: str) -> Generator[Union[str, Path], None, None]:
"""Only for unified API and do nothing."""
storage = File._get_storage(uri)
with storage.as_local_path(uri) as local_path:
yield local_path

View File

@ -1,359 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import struct
from typing import Any, Dict, List, Union
import torchaudio
import librosa
import numpy as np
import pkg_resources
from modelscope.utils.logger import get_logger
logger = get_logger()
green_color = '\033[1;32m'
red_color = '\033[0;31;40m'
yellow_color = '\033[0;33;40m'
end_color = '\033[0m'
global_asr_language = 'zh-cn'
SUPPORT_AUDIO_TYPE_SETS = ['flac', 'mp3', 'ogg', 'opus', 'wav', 'pcm']
def get_version():
return float(pkg_resources.get_distribution('easyasr').version)
def sample_rate_checking(audio_in: Union[str, bytes], audio_format: str):
r_audio_fs = None
if audio_format == 'wav' or audio_format == 'scp':
r_audio_fs = get_sr_from_wav(audio_in)
elif audio_format == 'pcm' and isinstance(audio_in, bytes):
r_audio_fs = get_sr_from_bytes(audio_in)
return r_audio_fs
def type_checking(audio_in: Union[str, bytes],
audio_fs: int = None,
recog_type: str = None,
audio_format: str = None):
r_recog_type = recog_type
r_audio_format = audio_format
r_wav_path = audio_in
if isinstance(audio_in, str):
assert os.path.exists(audio_in), f'wav_path:{audio_in} does not exist'
elif isinstance(audio_in, bytes):
assert len(audio_in) > 0, 'audio in is empty'
r_audio_format = 'pcm'
r_recog_type = 'wav'
if audio_in is None:
# for raw_inputs
r_recog_type = 'wav'
r_audio_format = 'pcm'
if r_recog_type is None and audio_in is not None:
# audio_in is wav, recog_type is wav_file
if os.path.isfile(audio_in):
audio_type = os.path.basename(audio_in).lower()
for support_audio_type in SUPPORT_AUDIO_TYPE_SETS:
if audio_type.rfind(".{}".format(support_audio_type)) >= 0:
r_recog_type = 'wav'
r_audio_format = 'wav'
if audio_type.rfind(".scp") >= 0:
r_recog_type = 'wav'
r_audio_format = 'scp'
if r_recog_type is None:
raise NotImplementedError(
f'Not supported audio type: {audio_type}')
# recog_type is datasets_file
elif os.path.isdir(audio_in):
dir_name = os.path.basename(audio_in)
if 'test' in dir_name:
r_recog_type = 'test'
elif 'dev' in dir_name:
r_recog_type = 'dev'
elif 'train' in dir_name:
r_recog_type = 'train'
if r_audio_format is None:
if find_file_by_ends(audio_in, '.ark'):
r_audio_format = 'kaldi_ark'
elif find_file_by_ends(audio_in, '.wav') or find_file_by_ends(
audio_in, '.WAV'):
r_audio_format = 'wav'
elif find_file_by_ends(audio_in, '.records'):
r_audio_format = 'tfrecord'
if r_audio_format == 'kaldi_ark' and r_recog_type != 'wav':
# datasets with kaldi_ark file
r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../'))
elif r_audio_format == 'tfrecord' and r_recog_type != 'wav':
# datasets with tensorflow records file
r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../'))
elif r_audio_format == 'wav' and r_recog_type != 'wav':
# datasets with waveform files
r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../../'))
return r_recog_type, r_audio_format, r_wav_path
def get_sr_from_bytes(wav: bytes):
sr = None
data = wav
if len(data) > 44:
try:
header_fields = {}
header_fields['ChunkID'] = str(data[0:4], 'UTF-8')
header_fields['Format'] = str(data[8:12], 'UTF-8')
header_fields['Subchunk1ID'] = str(data[12:16], 'UTF-8')
if header_fields['ChunkID'] == 'RIFF' and header_fields[
'Format'] == 'WAVE' and header_fields[
'Subchunk1ID'] == 'fmt ':
header_fields['SampleRate'] = struct.unpack('<I',
data[24:28])[0]
sr = header_fields['SampleRate']
except Exception:
# no treatment
pass
else:
logger.warn('audio bytes is ' + str(len(data)) + ' is invalid.')
return sr
def get_sr_from_wav(fname: str):
fs = None
if os.path.isfile(fname):
audio_type = os.path.basename(fname).lower()
for support_audio_type in SUPPORT_AUDIO_TYPE_SETS:
if audio_type.rfind(".{}".format(support_audio_type)) >= 0:
if support_audio_type == "pcm":
fs = None
else:
try:
audio, fs = torchaudio.load(fname)
except:
audio, fs = librosa.load(fname)
break
if audio_type.rfind(".scp") >= 0:
with open(fname, encoding="utf-8") as f:
for line in f:
wav_path = line.split()[1]
fs = get_sr_from_wav(wav_path)
if fs is not None:
break
return fs
elif os.path.isdir(fname):
dir_files = os.listdir(fname)
for file in dir_files:
file_path = os.path.join(fname, file)
if os.path.isfile(file_path):
fs = get_sr_from_wav(file_path)
elif os.path.isdir(file_path):
fs = get_sr_from_wav(file_path)
if fs is not None:
break
return fs
def find_file_by_ends(dir_path: str, ends: str):
dir_files = os.listdir(dir_path)
for file in dir_files:
file_path = os.path.join(dir_path, file)
if os.path.isfile(file_path):
if ends == ".wav" or ends == ".WAV":
audio_type = os.path.basename(file_path).lower()
for support_audio_type in SUPPORT_AUDIO_TYPE_SETS:
if audio_type.rfind(".{}".format(support_audio_type)) >= 0:
return True
raise NotImplementedError(
f'Not supported audio type: {audio_type}')
elif file_path.endswith(ends):
return True
elif os.path.isdir(file_path):
if find_file_by_ends(file_path, ends):
return True
return False
def recursion_dir_all_wav(wav_list, dir_path: str) -> List[str]:
dir_files = os.listdir(dir_path)
for file in dir_files:
file_path = os.path.join(dir_path, file)
if os.path.isfile(file_path):
audio_type = os.path.basename(file_path).lower()
for support_audio_type in SUPPORT_AUDIO_TYPE_SETS:
if audio_type.rfind(".{}".format(support_audio_type)) >= 0:
wav_list.append(file_path)
elif os.path.isdir(file_path):
recursion_dir_all_wav(wav_list, file_path)
return wav_list
def compute_wer(hyp_list: List[Any],
ref_list: List[Any],
lang: str = None) -> Dict[str, Any]:
assert len(hyp_list) > 0, 'hyp list is empty'
assert len(ref_list) > 0, 'ref list is empty'
rst = {
'Wrd': 0,
'Corr': 0,
'Ins': 0,
'Del': 0,
'Sub': 0,
'Snt': 0,
'Err': 0.0,
'S.Err': 0.0,
'wrong_words': 0,
'wrong_sentences': 0
}
if lang is None:
lang = global_asr_language
for h_item in hyp_list:
for r_item in ref_list:
if h_item['key'] == r_item['key']:
out_item = compute_wer_by_line(h_item['value'],
r_item['value'],
lang)
rst['Wrd'] += out_item['nwords']
rst['Corr'] += out_item['cor']
rst['wrong_words'] += out_item['wrong']
rst['Ins'] += out_item['ins']
rst['Del'] += out_item['del']
rst['Sub'] += out_item['sub']
rst['Snt'] += 1
if out_item['wrong'] > 0:
rst['wrong_sentences'] += 1
print_wrong_sentence(key=h_item['key'],
hyp=h_item['value'],
ref=r_item['value'])
else:
print_correct_sentence(key=h_item['key'],
hyp=h_item['value'],
ref=r_item['value'])
break
if rst['Wrd'] > 0:
rst['Err'] = round(rst['wrong_words'] * 100 / rst['Wrd'], 2)
if rst['Snt'] > 0:
rst['S.Err'] = round(rst['wrong_sentences'] * 100 / rst['Snt'], 2)
return rst
def compute_wer_by_line(hyp: List[str],
ref: List[str],
lang: str = 'zh-cn') -> Dict[str, Any]:
if lang != 'zh-cn':
hyp = hyp.split()
ref = ref.split()
hyp = list(map(lambda x: x.lower(), hyp))
ref = list(map(lambda x: x.lower(), ref))
len_hyp = len(hyp)
len_ref = len(ref)
cost_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int16)
ops_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int8)
for i in range(len_hyp + 1):
cost_matrix[i][0] = i
for j in range(len_ref + 1):
cost_matrix[0][j] = j
for i in range(1, len_hyp + 1):
for j in range(1, len_ref + 1):
if hyp[i - 1] == ref[j - 1]:
cost_matrix[i][j] = cost_matrix[i - 1][j - 1]
else:
substitution = cost_matrix[i - 1][j - 1] + 1
insertion = cost_matrix[i - 1][j] + 1
deletion = cost_matrix[i][j - 1] + 1
compare_val = [substitution, insertion, deletion]
min_val = min(compare_val)
operation_idx = compare_val.index(min_val) + 1
cost_matrix[i][j] = min_val
ops_matrix[i][j] = operation_idx
match_idx = []
i = len_hyp
j = len_ref
rst = {
'nwords': len_ref,
'cor': 0,
'wrong': 0,
'ins': 0,
'del': 0,
'sub': 0
}
while i >= 0 or j >= 0:
i_idx = max(0, i)
j_idx = max(0, j)
if ops_matrix[i_idx][j_idx] == 0: # correct
if i - 1 >= 0 and j - 1 >= 0:
match_idx.append((j - 1, i - 1))
rst['cor'] += 1
i -= 1
j -= 1
elif ops_matrix[i_idx][j_idx] == 2: # insert
i -= 1
rst['ins'] += 1
elif ops_matrix[i_idx][j_idx] == 3: # delete
j -= 1
rst['del'] += 1
elif ops_matrix[i_idx][j_idx] == 1: # substitute
i -= 1
j -= 1
rst['sub'] += 1
if i < 0 and j >= 0:
rst['del'] += 1
elif j < 0 and i >= 0:
rst['ins'] += 1
match_idx.reverse()
wrong_cnt = cost_matrix[len_hyp][len_ref]
rst['wrong'] = wrong_cnt
return rst
def print_wrong_sentence(key: str, hyp: str, ref: str):
space = len(key)
print(key + yellow_color + ' ref: ' + ref)
print(' ' * space + red_color + ' hyp: ' + hyp + end_color)
def print_correct_sentence(key: str, hyp: str, ref: str):
space = len(key)
print(key + yellow_color + ' ref: ' + ref)
print(' ' * space + green_color + ' hyp: ' + hyp + end_color)
def print_progress(percent):
if percent > 1:
percent = 1
res = int(50 * percent) * '#'
print('\r[%-50s] %d%%' % (res, int(100 * percent)), end='')

View File

@ -9,7 +9,12 @@ import torchaudio
import time
import logging
from torch.nn.utils.rnn import pad_sequence
try:
from urllib.parse import urlparse
from funasr.download.file import HTTPStorage
import tempfile
except:
print("urllib is not installed, if you infer from url, please install it first.")
# def load_audio(data_or_path_or_list, fs: int=16000, audio_fs: int=16000):
#
# if isinstance(data_or_path_or_list, (list, tuple)):
@ -43,7 +48,8 @@ def load_audio_text_image_video(data_or_path_or_list, fs: int = 16000, audio_fs:
return data_or_path_or_list_ret
else:
return [load_audio_text_image_video(audio, fs=fs, audio_fs=audio_fs) for audio in data_or_path_or_list]
if isinstance(data_or_path_or_list, str) and data_or_path_or_list.startswith('http'):
data_or_path_or_list = download_from_url(data_or_path_or_list)
if isinstance(data_or_path_or_list, str) and os.path.exists(data_or_path_or_list):
data_or_path_or_list, audio_fs = torchaudio.load(data_or_path_or_list)
data_or_path_or_list = data_or_path_or_list[0, :]
@ -99,4 +105,21 @@ def extract_fbank(data, data_len = None, data_type: str="sound", frontend=None):
if isinstance(data_len, (list, tuple)):
data_len = torch.tensor([data_len])
return data.to(torch.float32), data_len.to(torch.int32)
return data.to(torch.float32), data_len.to(torch.int32)
def download_from_url(url):
result = urlparse(url)
file_path = None
if result.scheme is not None and len(result.scheme) > 0:
storage = HTTPStorage()
# bytes
data = storage.read(url)
work_dir = tempfile.TemporaryDirectory().name
if not os.path.exists(work_dir):
os.makedirs(work_dir)
file_path = os.path.join(work_dir, os.path.basename(url))
with open(file_path, 'wb') as fb:
fb.write(data)
assert file_path is not None, f"failed to download: {url}"
return file_path