funasr1.0

This commit is contained in:
游雁 2024-01-05 11:52:48 +08:00
parent e82bd3d226
commit 32905d8cde
21 changed files with 238 additions and 120 deletions

View File

@ -1,14 +1,13 @@
# download model
local_path_root=./modelscope_models
local_path_root=../modelscope_models
mkdir -p ${local_path_root}
local_path=${local_path_root}/punc_ct-transformer_zh-cn-common-vocab272727-pytorch
git clone https://www.modelscope.cn/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404.git ${local_path}
git clone https://www.modelscope.cn/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch.git ${local_path}
python funasr/bin/inference.py \
+model="${local_path}" \
+input="${local_path}/example/punc_example.txt" \
+output_dir="./outputs/debug" \
+device="cpu" \
+debug="true"
+device="cpu"

View File

@ -5,7 +5,7 @@
from funasr import AutoModel
model = AutoModel(model="/Users/zhifu/Downloads/modelscope_models/speech_fsmn_vad_zh-cn-16k-common-pytorch")
model = AutoModel(model="../modelscope_models/speech_fsmn_vad_zh-cn-16k-common-pytorch")
res = model(input="/Users/zhifu/Downloads/modelscope_models/speech_fsmn_vad_zh-cn-16k-common-pytorch/example/vad_example.wav")
res = model(input="../modelscope_models/speech_fsmn_vad_zh-cn-16k-common-pytorch/example/vad_example.wav")
print(res)

View File

@ -1,6 +1,6 @@
# download model
local_path_root=./modelscope_models
local_path_root=../modelscope_models
mkdir -p ${local_path_root}
local_path=${local_path_root}/speech_fsmn_vad_zh-cn-16k-common-pytorch
git clone https://www.modelscope.cn/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch.git ${local_path}

View File

@ -5,10 +5,10 @@
from funasr import AutoModel
model = AutoModel(model="../modelscope_models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
model = AutoModel(model="../modelscope_models/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
vad_model="../modelscope_models/speech_fsmn_vad_zh-cn-16k-common-pytorch",
punc_model="../modelscope_models/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
)
res = model(input="../modelscope_models/speech_fsmn_vad_zh-cn-16k-common-pytorch/example/asr_example.wav", batch_size_s=300, batch_size_threshold_s=60)
res = model(input="../modelscope_models/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav", batch_size_s=300, batch_size_threshold_s=60)
print(res)

View File

@ -3,8 +3,8 @@
local_path_root=../modelscope_models
mkdir -p ${local_path_root}
local_path=${local_path_root}/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch
git clone https://www.modelscope.cn/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404.git ${local_path}
local_path=${local_path_root}/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch
git clone https://www.modelscope.cn/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch.git ${local_path}
local_path_vad=${local_path_root}/speech_fsmn_vad_zh-cn-16k-common-pytorch
git clone https://www.modelscope.cn/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch.git ${local_path_vad}

View File

@ -5,12 +5,12 @@
from funasr import AutoModel
model = AutoModel(model="../modelscope_models/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404",
model = AutoModel(model="../modelscope_models/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
)
#vad_model="../modelscope_models/speech_fsmn_vad_zh-cn-16k-common-pytorch",
#punc_model="../modelscope_models/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
res = model(input="../modelscope_models/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/example/asr_example.wav",
res = model(input="../modelscope_models/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav",
hotword='达摩院 磨搭')
print(res)

View File

@ -0,0 +1,42 @@
(简体中文|[English](./README.md))
# 语音识别
> **注意**:
> pipeline 支持 [modelscope模型仓库](https://alibaba-damo-academy.github.io/FunASR/en/model_zoo/modelscope_models.html#pretrained-models-on-modelscope) 中的所有模型进行推理和微调。这里我们以典型模型作为示例来演示使用方法。
## 推理
### 快速使用
#### [Paraformer 模型](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary)
```python
from funasr import AutoModel
model = AutoModel(model="/Users/zhifu/Downloads/modelscope_models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
res = model(input="/Users/zhifu/Downloads/modelscope_models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav")
print(res)
```
### API接口说明
#### AutoModel 定义
- `model`: [模型仓库](https://alibaba-damo-academy.github.io/FunASR/en/model_zoo/modelscope_models.html#pretrained-models-on-modelscope) 中的模型名称,或本地磁盘中的模型路径
- `device`: `cuda`(默认),使用 GPU 进行推理。如果为`cpu`,则使用 CPU 进行推理
- `ncpu`: `None` (默认),设置用于 CPU 内部操作并行性的线程数
- `output_dir`: `None` (默认),如果设置,输出结果的输出路径
- `batch_size`: `1` (默认),解码时的批处理大小
#### AutoModel 推理
- `input`: 要解码的输入,可以是:
- wav文件路径, 例如: asr_example.wav
- pcm文件路径, 例如: asr_example.pcm此时需要指定音频采样率fs默认为16000
- 音频字节数流,例如:麦克风的字节数数据
- wav.scpkaldi 样式的 wav 列表 (`wav_id \t wav_path`), 例如:
```text
asr_example1 ./audios/asr_example1.wav
asr_example2 ./audios/asr_example2.wav
```
在这种输入 `wav.scp` 的情况下,必须设置 `output_dir` 以保存输出结果
- 音频采样点,例如:`audio, rate = soundfile.read("asr_example_zh.wav")`, 数据类型为 numpy.ndarray。支持batch输入类型为list
```[audio_sample1, audio_sample2, ..., audio_sampleN]```
- fbank输入支持组batch。shape为[batch, frames, dim]类型为torch.Tensor例如
- `output_dir`: None (默认),如果设置,输出结果的输出路径

View File

@ -0,0 +1,15 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
from funasr import AutoModel
model = AutoModel(model="/Users/zhifu/modelscope_models/speech_timestamp_prediction-v1-16k-offline")
res = model(input=("/Users/zhifu/funasr_github/test_local/wav.scp",
"/Users/zhifu/funasr_github/test_local/text.txt"),
data_type=("sound", "text"),
batch_size=2,
)
print(res)

View File

@ -0,0 +1,17 @@
# download model
local_path_root=../modelscope_models
mkdir -p ${local_path_root}
local_path=${local_path_root}/speech_timestamp_prediction-v1-16k-offline
git clone https://www.modelscope.cn/damo/speech_timestamp_prediction-v1-16k-offline.git ${local_path}
python funasr/bin/inference.py \
+model="${local_path}" \
+input='["/Users/zhifu/funasr_github/test_local/wav.scp", "/Users/zhifu/funasr_github/test_local/text.txt"]' \
+data_type='["sound", "text"]' \
+output_dir="./outputs/debug" \
+device="cpu" \
+batch_size=2 \
+debug="true"

View File

@ -4,11 +4,11 @@ import torch
import numpy as np
import hydra
import json
from omegaconf import DictConfig, OmegaConf
from omegaconf import DictConfig, OmegaConf, ListConfig
import logging
from funasr.download.download_from_hub import download_model
from funasr.train_utils.set_all_random_seed import set_all_random_seed
from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_bytes
from funasr.utils.load_utils import load_bytes
from funasr.train_utils.device_funcs import to_device
from tqdm import tqdm
from funasr.train_utils.load_pretrained_model import load_pretrained_model
@ -17,11 +17,11 @@ import random
import string
from funasr.register import tables
from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_audio, extract_fbank
from funasr.utils.load_utils import load_audio_and_text_image_video, extract_fbank
from funasr.utils.vad_utils import slice_padding_audio_samples
from funasr.utils.timestamp_tools import time_stamp_sentence
def build_iter_for_infer(data_in, input_len=None, data_type="sound", key=None):
def build_iter_for_infer(data_in, input_len=None, data_type=None, key=None):
"""
:param input:
@ -58,9 +58,19 @@ def build_iter_for_infer(data_in, input_len=None, data_type="sound", key=None):
key = "rand_key_" + ''.join(random.choice(chars) for _ in range(13))
data_list = [data_in]
key_list = [key]
elif isinstance(data_in, (list, tuple)): # [audio sample point, fbank]
data_list = data_in
key_list = ["rand_key_" + ''.join(random.choice(chars) for _ in range(13)) for _ in range(len(data_in))]
elif isinstance(data_in, (list, tuple)):
if data_type is not None and isinstance(data_type, (list, tuple)):
data_list_tmp = []
for data_in_i, data_type_i in zip(data_in, data_type):
key_list, data_list_i = build_iter_for_infer(data_in=data_in_i, data_type=data_type_i)
data_list_tmp.append(data_list_i)
data_list = []
for item in zip(*data_list_tmp):
data_list.append(item)
else:
# [audio sample point, fbank]
data_list = data_in
key_list = ["rand_key_" + ''.join(random.choice(chars) for _ in range(13)) for _ in range(len(data_in))]
else: # raw text; audio sample point, fbank; bytes
if isinstance(data_in, bytes): # audio bytes
data_in = load_bytes(data_in)
@ -72,7 +82,16 @@ def build_iter_for_infer(data_in, input_len=None, data_type="sound", key=None):
return key_list, data_list
@hydra.main(config_name=None, version_base=None)
def main_hydra(kwargs: DictConfig):
def main_hydra(cfg: DictConfig):
def to_plain_list(cfg_item):
if isinstance(cfg_item, ListConfig):
return OmegaConf.to_container(cfg_item, resolve=True)
elif isinstance(cfg_item, DictConfig):
return {k: to_plain_list(v) for k, v in cfg_item.items()}
else:
return cfg_item
kwargs = to_plain_list(cfg)
log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
logging.basicConfig(level=log_level)
@ -125,7 +144,7 @@ class AutoModel:
device = kwargs.get("device", "cuda")
if not torch.cuda.is_available() or kwargs.get("ngpu", 0):
device = "cpu"
kwargs["batch_size"] = 1
# kwargs["batch_size"] = 1
kwargs["device"] = device
if kwargs.get("ncpu", None):
@ -182,8 +201,8 @@ class AutoModel:
data_type = kwargs.get("data_type", "sound")
batch_size = kwargs.get("batch_size", 1)
if kwargs.get("device", "cpu") == "cpu":
batch_size = 1
# if kwargs.get("device", "cpu") == "cpu":
# batch_size = 1
key_list, data_list = build_iter_for_infer(input, input_len=input_len, data_type=data_type, key=key)
@ -259,7 +278,7 @@ class AutoModel:
key = res[i]["key"]
vadsegments = res[i]["value"]
input_i = data_list[i]
speech = load_audio(input_i, fs=kwargs["frontend"].fs, audio_fs=kwargs.get("fs", 16000))
speech = load_audio_and_text_image_video(input_i, fs=kwargs["frontend"].fs, audio_fs=kwargs.get("fs", 16000))
speech_lengths = len(speech)
n = len(vadsegments)
data_with_index = [(vadsegments[i], i) for i in range(n)]
@ -398,7 +417,7 @@ class AutoFrontend:
# extract fbank feats
time1 = time.perf_counter()
audio_sample_list = load_audio(data_batch, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000))
audio_sample_list = load_audio_and_text_image_video(data_batch, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000))
time2 = time.perf_counter()
meta_data["load_data"] = f"{time2 - time1:0.3f}"
speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),

View File

@ -8,7 +8,7 @@ import torchaudio
import time
import logging
from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_audio, extract_fbank
from funasr.utils.load_utils import load_audio_and_text_image_video, extract_fbank
from funasr.register import tables
@tables.register("dataset_classes", "AudioDataset")

View File

@ -1,76 +0,0 @@
import os
import torch
import json
import torch.distributed as dist
import numpy as np
import kaldiio
import librosa
import torchaudio
import time
import logging
from torch.nn.utils.rnn import pad_sequence
def load_audio(audio_or_path_or_list, fs: int=16000, audio_fs: int=16000):
if isinstance(audio_or_path_or_list, (list, tuple)):
return [load_audio(audio, fs=fs, audio_fs=audio_fs) for audio in audio_or_path_or_list]
if isinstance(audio_or_path_or_list, str) and os.path.exists(audio_or_path_or_list):
audio_or_path_or_list, audio_fs = torchaudio.load(audio_or_path_or_list)
audio_or_path_or_list = audio_or_path_or_list[0, :]
elif isinstance(audio_or_path_or_list, np.ndarray): # audio sample point
audio_or_path_or_list = np.squeeze(audio_or_path_or_list) #[n_samples,]
if audio_fs != fs:
resampler = torchaudio.transforms.Resample(audio_fs, fs)
resampled_waveform = resampler(audio_or_path_or_list[None, :])[0, :]
return audio_or_path_or_list
#
# def load_audio_from_list(audio_list, fs: int=16000, audio_fs: int=16000):
# if isinstance(audio_list, (list, tuple)):
# return [load_audio(audio_or_path, fs=fs, audio_fs=audio_fs) for audio_or_path in audio_list]
def load_bytes(input):
middle_data = np.frombuffer(input, dtype=np.int16)
middle_data = np.asarray(middle_data)
if middle_data.dtype.kind not in 'iu':
raise TypeError("'middle_data' must be an array of integers")
dtype = np.dtype('float32')
if dtype.kind != 'f':
raise TypeError("'dtype' must be a floating point type")
i = np.iinfo(middle_data.dtype)
abs_max = 2 ** (i.bits - 1)
offset = i.min + abs_max
array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
return array
def extract_fbank(data, data_len = None, data_type: str="sound", frontend=None):
# import pdb;
# pdb.set_trace()
if isinstance(data, np.ndarray):
data = torch.from_numpy(data)
if len(data.shape) < 2:
data = data[None, :] # data: [batch, N]
data_len = [data.shape[1]] if data_len is None else data_len
elif isinstance(data, torch.Tensor):
if len(data.shape) < 2:
data = data[None, :] # data: [batch, N]
data_len = [data.shape[1]] if data_len is None else data_len
elif isinstance(data, (list, tuple)):
data_list, data_len = [], []
for data_i in data:
if isinstance(data, np.ndarray):
data_i = torch.from_numpy(data_i)
data_list.append(data_i)
data_len.append(data_i.shape[0])
data = pad_sequence(data_list, batch_first=True) # data: [batch, N]
# import pdb;
# pdb.set_trace()
if data_type == "sound":
data, data_len = frontend(data, data_len)
if isinstance(data_len, (list, tuple)):
data_len = torch.tensor([data_len])
return data.to(torch.float32), data_len.to(torch.int32)

View File

@ -23,7 +23,7 @@ from funasr.train_utils.device_funcs import force_gatherable
from funasr.models.paraformer.search import Hypothesis
from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_audio, extract_fbank
from funasr.utils.load_utils import load_audio_and_text_image_video, extract_fbank
from funasr.utils import postprocess_utils
from funasr.utils.datadir_writer import DatadirWriter
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
@ -243,7 +243,7 @@ class BiCifParaformer(Paraformer):
else:
# extract fbank feats
time1 = time.perf_counter()
audio_sample_list = load_audio(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
audio_sample_list = load_audio_and_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
time2 = time.perf_counter()
meta_data["load_data"] = f"{time2 - time1:0.3f}"
speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),

View File

@ -46,7 +46,7 @@ else:
@contextmanager
def autocast(enabled=True):
yield
from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_audio, extract_fbank
from funasr.utils.load_utils import load_audio_and_text_image_video, extract_fbank
from funasr.utils import postprocess_utils
from funasr.utils.datadir_writer import DatadirWriter
@ -337,7 +337,7 @@ class ContextualParaformer(Paraformer):
# extract fbank feats
time1 = time.perf_counter()
audio_sample_list = load_audio(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
audio_sample_list = load_audio_and_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
time2 = time.perf_counter()
meta_data["load_data"] = f"{time2 - time1:0.3f}"
speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),

View File

@ -9,7 +9,7 @@ import math
from typing import Optional
import time
from funasr.register import tables
from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_audio,extract_fbank
from funasr.utils.load_utils import load_audio_and_text_image_video,extract_fbank
from funasr.utils.datadir_writer import DatadirWriter
from torch.nn.utils.rnn import pad_sequence
@ -544,7 +544,7 @@ class FsmnVAD(nn.Module):
else:
# extract fbank feats
time1 = time.perf_counter()
audio_sample_list = load_audio(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
audio_sample_list = load_audio_and_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
time2 = time.perf_counter()
meta_data["load_data"] = f"{time2 - time1:0.3f}"
speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),

View File

@ -22,7 +22,7 @@ from funasr.models.paraformer.search import Hypothesis
from torch.cuda.amp import autocast
from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_audio, extract_fbank
from funasr.utils.load_utils import load_audio_and_text_image_video, extract_fbank, load_audio_and_text_image_video
from funasr.utils import postprocess_utils
from funasr.utils.datadir_writer import DatadirWriter
from funasr.register import tables
@ -466,7 +466,7 @@ class Paraformer(nn.Module):
else:
# extract fbank feats
time1 = time.perf_counter()
audio_sample_list = load_audio(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
audio_sample_list = load_audio_and_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000), data_type=kwargs.get("data_type", "sound"), tokenizer=tokenizer)
time2 = time.perf_counter()
meta_data["load_data"] = f"{time2 - time1:0.3f}"
speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend)

View File

@ -40,7 +40,7 @@ else:
@contextmanager
def autocast(enabled=True):
yield
from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_audio, extract_fbank
from funasr.utils.load_utils import load_audio_and_text_image_video, extract_fbank
from funasr.utils import postprocess_utils
from funasr.utils.datadir_writer import DatadirWriter
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
@ -483,7 +483,7 @@ class Paraformer(nn.Module):
meta_data = {}
# extract fbank feats
time1 = time.perf_counter()
audio_sample_list = load_audio(data_in, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000))
audio_sample_list = load_audio_and_text_image_video(data_in, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000))
time2 = time.perf_counter()
meta_data["load_data"] = f"{time2 - time1:0.3f}"
speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=self.frontend)
@ -761,7 +761,7 @@ class BiCifParaformer(Paraformer):
meta_data = {}
# extract fbank feats
time1 = time.perf_counter()
audio_sample_list = load_audio(data_in, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000))
audio_sample_list = load_audio_and_text_image_video(data_in, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000))
time2 = time.perf_counter()
meta_data["load_data"] = f"{time2 - time1:0.3f}"
speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),

View File

@ -35,7 +35,7 @@ else:
@contextmanager
def autocast(enabled=True):
yield
from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_audio, extract_fbank
from funasr.utils.load_utils import load_audio_and_text_image_video, extract_fbank
from funasr.utils import postprocess_utils
from funasr.utils.datadir_writer import DatadirWriter
@ -327,7 +327,7 @@ class SeacoParaformer(BiCifParaformer, Paraformer):
# extract fbank feats
time1 = time.perf_counter()
audio_sample_list = load_audio(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
audio_sample_list = load_audio_and_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
time2 = time.perf_counter()
meta_data["load_data"] = f"{time2 - time1:0.3f}"
speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),

View File

@ -45,7 +45,7 @@ else:
@contextmanager
def autocast(enabled=True):
yield
from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_audio, extract_fbank
from funasr.utils.load_utils import load_audio_and_text_image_video, extract_fbank
from funasr.utils import postprocess_utils
from funasr.utils.datadir_writer import DatadirWriter
from funasr.models.transformer.utils.nets_utils import get_transducer_task_io
@ -517,7 +517,7 @@ class Transducer(nn.Module):
meta_data = {}
# extract fbank feats
time1 = time.perf_counter()
audio_sample_list = load_audio(data_in, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000))
audio_sample_list = load_audio_and_text_image_video(data_in, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000))
time2 = time.perf_counter()
meta_data["load_data"] = f"{time2 - time1:0.3f}"
speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=self.frontend)

View File

@ -12,7 +12,7 @@ from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.metrics.compute_acc import th_accuracy
# from funasr.models.e2e_asr_common import ErrorCalculator
from funasr.train_utils.device_funcs import force_gatherable
from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_audio, extract_fbank
from funasr.utils.load_utils import load_audio_and_text_image_video, extract_fbank
from funasr.utils import postprocess_utils
from funasr.utils.datadir_writer import DatadirWriter
from funasr.register import tables
@ -392,7 +392,7 @@ class Transformer(nn.Module):
meta_data = {}
# extract fbank feats
time1 = time.perf_counter()
audio_sample_list = load_audio(data_in, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000))
audio_sample_list = load_audio_and_text_image_video(data_in, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000))
time2 = time.perf_counter()
meta_data["load_data"] = f"{time2 - time1:0.3f}"
speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=self.frontend)

102
funasr/utils/load_utils.py Normal file
View File

@ -0,0 +1,102 @@
import os
import torch
import json
import torch.distributed as dist
import numpy as np
import kaldiio
import librosa
import torchaudio
import time
import logging
from torch.nn.utils.rnn import pad_sequence
# def load_audio(audio_or_path_or_list, fs: int=16000, audio_fs: int=16000):
#
# if isinstance(audio_or_path_or_list, (list, tuple)):
# return [load_audio(audio, fs=fs, audio_fs=audio_fs) for audio in audio_or_path_or_list]
#
# if isinstance(audio_or_path_or_list, str) and os.path.exists(audio_or_path_or_list):
# audio_or_path_or_list, audio_fs = torchaudio.load(audio_or_path_or_list)
# audio_or_path_or_list = audio_or_path_or_list[0, :]
# elif isinstance(audio_or_path_or_list, np.ndarray): # audio sample point
# audio_or_path_or_list = np.squeeze(audio_or_path_or_list) #[n_samples,]
#
# if audio_fs != fs:
# resampler = torchaudio.transforms.Resample(audio_fs, fs)
# audio_or_path_or_list = resampler(audio_or_path_or_list[None, :])[0, :]
# return audio_or_path_or_list
def load_audio_and_text_image_video(audio_or_path_or_list, fs: int = 16000, audio_fs: int = 16000, data_type=None, tokenizer=None):
if isinstance(audio_or_path_or_list, (list, tuple)):
if data_type is not None and isinstance(data_type, (list, tuple)):
data_types = [data_type] * len(audio_or_path_or_list)
audio_or_path_or_list_ret = [[] for d in data_type]
for i, (data_type_i, audio_or_path_or_list_i) in enumerate(zip(data_types, audio_or_path_or_list)):
for j, (data_type_j, audio_or_path_or_list_j) in enumerate(zip(data_type_i, audio_or_path_or_list_i)):
audio_or_path_or_list_j = load_audio_and_text_image_video(audio_or_path_or_list_j, fs=fs, audio_fs=audio_fs, data_type=data_type_j, tokenizer=tokenizer)
audio_or_path_or_list_ret[j].append(audio_or_path_or_list_j)
return audio_or_path_or_list_ret
else:
return [load_audio_and_text_image_video(audio, fs=fs, audio_fs=audio_fs) for audio in audio_or_path_or_list]
if isinstance(audio_or_path_or_list, str) and os.path.exists(audio_or_path_or_list):
audio_or_path_or_list, audio_fs = torchaudio.load(audio_or_path_or_list)
audio_or_path_or_list = audio_or_path_or_list[0, :]
elif isinstance(audio_or_path_or_list, np.ndarray): # audio sample point
audio_or_path_or_list = np.squeeze(audio_or_path_or_list) # [n_samples,]
elif isinstance(audio_or_path_or_list, str) and data_type is not None and data_type == "text" and tokenizer is not None:
audio_or_path_or_list = tokenizer.encode(audio_or_path_or_list)
if audio_fs != fs and data_type != "text":
resampler = torchaudio.transforms.Resample(audio_fs, fs)
audio_or_path_or_list = resampler(audio_or_path_or_list[None, :])[0, :]
return audio_or_path_or_list
def load_bytes(input):
middle_data = np.frombuffer(input, dtype=np.int16)
middle_data = np.asarray(middle_data)
if middle_data.dtype.kind not in 'iu':
raise TypeError("'middle_data' must be an array of integers")
dtype = np.dtype('float32')
if dtype.kind != 'f':
raise TypeError("'dtype' must be a floating point type")
i = np.iinfo(middle_data.dtype)
abs_max = 2 ** (i.bits - 1)
offset = i.min + abs_max
array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
return array
def extract_fbank(data, data_len = None, data_type: str="sound", frontend=None):
# import pdb;
# pdb.set_trace()
if isinstance(data, np.ndarray):
data = torch.from_numpy(data)
if len(data.shape) < 2:
data = data[None, :] # data: [batch, N]
data_len = [data.shape[1]] if data_len is None else data_len
elif isinstance(data, torch.Tensor):
if len(data.shape) < 2:
data = data[None, :] # data: [batch, N]
data_len = [data.shape[1]] if data_len is None else data_len
elif isinstance(data, (list, tuple)):
data_list, data_len = [], []
for data_i in data:
if isinstance(data, np.ndarray):
data_i = torch.from_numpy(data_i)
data_list.append(data_i)
data_len.append(data_i.shape[0])
data = pad_sequence(data_list, batch_first=True) # data: [batch, N]
# import pdb;
# pdb.set_trace()
# if data_type == "sound":
data, data_len = frontend(data, data_len)
if isinstance(data_len, (list, tuple)):
data_len = torch.tensor([data_len])
return data.to(torch.float32), data_len.to(torch.int32)