funasr1.0 paraformer_streaming WavFrontendOnline

游雁 2024-01-10 17:42:53 +08:00
parent 6eaf50a063
commit 1028a8a036
8 changed files with 172 additions and 129 deletions

View File

@ -0,0 +1,42 @@
(Simplified Chinese|[English](./README.md))
# Speech Recognition
> **Note**:
> The pipeline supports inference and fine-tuning with all models in the [modelscope model zoo](https://alibaba-damo-academy.github.io/FunASR/en/model_zoo/modelscope_models.html#pretrained-models-on-modelscope). Here we take typical models as examples to demonstrate usage.
## Inference
### Quick Start
#### [Paraformer Model](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary)
```python
from funasr import AutoModel
model = AutoModel(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")
print(res)
```
### API Reference
#### AutoModel Definition
- `model`: a model name from the [model zoo](https://alibaba-damo-academy.github.io/FunASR/en/model_zoo/modelscope_models.html#pretrained-models-on-modelscope), or a path to a model on local disk
- `device`: `cuda` (default), run inference on GPU; if set to `cpu`, inference runs on CPU
- `ncpu`: `None` (default), the number of threads used for CPU intra-op parallelism
- `output_dir`: `None` (default); if set, the path where output results are saved
- `batch_size`: `1` (default), the batch size for decoding (see the example below)
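For illustration, a sketch combining these options (the values are arbitrary; the model name follows the quick start above):
```python
from funasr import AutoModel

# illustrative values only; any model from the model zoo works the same way
model = AutoModel(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                  device="cpu",            # CPU inference instead of the default cuda
                  ncpu=4,                  # threads for CPU intra-op parallelism
                  output_dir="./outputs",  # save results here
                  batch_size=1)
```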
#### AutoModel 推理
- `input`: the input to decode, which may be:
  - a wav file path, e.g., asr_example.wav
  - a pcm file path, e.g., asr_example.pcm; in this case the audio sampling rate `fs` must be specified (default 16000)
  - an audio byte stream, e.g., byte data from a microphone
  - wav.scp, a Kaldi-style wav list (`wav_id \t wav_path`), e.g.:
    ```text
    asr_example1 ./audios/asr_example1.wav
    asr_example2 ./audios/asr_example2.wav
    ```
    For `wav.scp` input, `output_dir` must be set to save the output results.
  - audio samples, e.g., `audio, rate = soundfile.read("asr_example_zh.wav")`, with data type numpy.ndarray; batch input as a list is supported (see the sketch below):
  ```[audio_sample1, audio_sample2, ..., audio_sampleN]```
  - fbank input, with batching supported; shape is [batch, frames, dim], type torch.Tensor
- `output_dir`: `None` (default); if set, the path where output results are saved
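For illustration, a sketch of the list-of-samples input described above (the file names are placeholders):
```python
import soundfile
from funasr import AutoModel

model = AutoModel(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                  batch_size=2)
# a batch of audio sample arrays (numpy.ndarray), as described above
audio1, rate = soundfile.read("./audios/asr_example1.wav")
audio2, rate = soundfile.read("./audios/asr_example2.wav")
res = model(input=[audio1, audio2])
print(res)
```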

View File

@ -0,0 +1,38 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
# from funasr import AutoModel
#
# model = AutoModel(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", model_revision="v2.0.0")
#
# res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")
# print(res)
from funasr import AutoFrontend
frontend = AutoFrontend(model="/Users/zhifu/Downloads/modelscope_models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online", model_revision="v2.0.0")
import soundfile
speech, sample_rate = soundfile.read("/Users/zhifu/Downloads/modelscope_models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/example/asr_example.wav")
chunk_size = [0, 10, 5]  # [0, 10, 5] -> 600 ms; [0, 8, 4] -> 480 ms
chunk_stride = chunk_size[1] * 960  # 10 * 960 samples = 9600 samples, i.e. 600 ms at 16 kHz

cache = {}  # streaming cache, carried across chunks
total_chunk_num = int((len(speech) - 1) / chunk_stride) + 1  # ceil(len(speech) / chunk_stride)
for i in range(total_chunk_num):
    speech_chunk = speech[i * chunk_stride:(i + 1) * chunk_stride]
    fbanks = frontend(input=speech_chunk,
                      batch_size=2,
                      cache=cache)
    # for batch_idx, fbank_dict in enumerate(fbanks):
    #     res = model(**fbank_dict)
    #     print(res)
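For reference, a sketch of how the extracted features might then be consumed, following the commented-out lines above (whether the streaming model accepts these fbank dicts directly is an assumption; the model name mirrors the online model used above):
```python
# a sketch under assumptions: AutoModel can be fed the fbank dicts produced by
# AutoFrontend, as the commented-out lines above suggest
from funasr import AutoModel

model = AutoModel(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online",
                  model_revision="v2.0.0")
for fbank_dict in fbanks:  # fbanks from the final chunk of the loop above
    res = model(**fbank_dict)
    print(res)
```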

View File

@ -0,0 +1,14 @@
# download model
local_path_root=../modelscope_models
mkdir -p ${local_path_root}
local_path=${local_path_root}/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch
git clone https://www.modelscope.cn/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch.git ${local_path}
python funasr/bin/train.py \
+model="../modelscope_models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
+token_list="../modelscope_models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/tokens.txt" \
+train_data_set_list="data/list/audio_datasets.jsonl" \
+output_dir="outputs/debug/ckpt/funasr2/exp2" \
+device="cpu"

View File

@ -0,0 +1,11 @@
model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
model_revision="v2.0.0"
python funasr/bin/inference.py \
+model=${model} \
+model_revision=${model_revision} \
+input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav" \
+output_dir="./outputs/debug" \
+device="cpu" \

View File

@ -391,7 +391,10 @@ class AutoFrontend:
        frontend = frontend_class(**kwargs["frontend_conf"])
        self.frontend = frontend
        if "frontend" in kwargs:
            del kwargs["frontend"]
        self.kwargs = kwargs

    def __call__(self, input, input_len=None, kwargs=None, **cfg):
@ -423,7 +426,7 @@ class AutoFrontend:
        time2 = time.perf_counter()
        meta_data["load_data"] = f"{time2 - time1:0.3f}"
        speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
                                               frontend=self.frontend)
                                               frontend=self.frontend, **kwargs)
        time3 = time.perf_counter()
        meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
        meta_data["batch_data_time"] = speech_lengths.sum().item() * self.frontend.frame_shift * self.frontend.lfr_n / 1000

View File

@ -1,7 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from espnet/espnet.
from typing import Tuple
import copy
import numpy as np
import torch
import torch.nn as nn
@ -119,7 +119,9 @@ class WavFrontend(nn.Module):
    def forward(
            self,
            input: torch.Tensor,
            input_lengths) -> Tuple[torch.Tensor, torch.Tensor]:
            input_lengths,
            **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = input.size(0)
        feats = []
        feats_lens = []
@ -249,13 +251,13 @@ class WavFrontendOnline(nn.Module):
        self.dither = dither
        self.snip_edges = snip_edges
        self.upsacle_samples = upsacle_samples
        self.waveforms = None
        self.reserve_waveforms = None
        self.fbanks = None
        self.fbanks_lens = None
        # self.waveforms = None
        # self.reserve_waveforms = None
        # self.fbanks = None
        # self.fbanks_lens = None
        self.cmvn = None if self.cmvn_file is None else load_cmvn(self.cmvn_file)
        self.input_cache = None
        self.lfr_splice_cache = []
        # self.input_cache = None
        # self.lfr_splice_cache = []

    def output_size(self) -> int:
        return self.n_mels * self.lfr_m
@ -278,9 +280,6 @@ class WavFrontendOnline(nn.Module):
        return inputs.type(torch.float32)

    @staticmethod
    # inputs tensor has catted the cache tensor
    # def apply_lfr(inputs: torch.Tensor, lfr_m: int, lfr_n: int, inputs_lfr_cache: torch.Tensor = None,
    #               is_final: bool = False) -> Tuple[torch.Tensor, torch.Tensor, int]:
    def apply_lfr(inputs: torch.Tensor, lfr_m: int, lfr_n: int, is_final: bool = False) -> Tuple[
            torch.Tensor, torch.Tensor, int]:
        """
@ -319,15 +318,16 @@ class WavFrontendOnline(nn.Module):
    def forward_fbank(
            self,
            input: torch.Tensor,
            input_lengths: torch.Tensor
            input_lengths: torch.Tensor,
            cache: dict = {},
            **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        batch_size = input.size(0)
        if self.input_cache is None:
            self.input_cache = torch.empty(0)
        input = torch.cat((self.input_cache, input), dim=1)
        input = torch.cat((cache["input_cache"], input), dim=1)
        frame_num = self.compute_frame_num(input.shape[-1], self.frame_sample_length, self.frame_shift_sample_length)
        # update the input cache
        self.input_cache = input[:, -(input.shape[-1] - frame_num * self.frame_shift_sample_length):]
        cache["input_cache"] = input[:, -(input.shape[-1] - frame_num * self.frame_shift_sample_length):]
        waveforms = torch.empty(0)
        feats_pad = torch.empty(0)
        feats_lens = torch.empty(0)
@ -360,20 +360,19 @@ class WavFrontendOnline(nn.Module):
            feats_pad = pad_sequence(feats,
                                     batch_first=True,
                                     padding_value=0.0)
            self.fbanks = feats_pad
            import copy
            self.fbanks_lens = copy.deepcopy(feats_lens)
            cache["fbanks"] = feats_pad
            cache["fbanks_lens"] = copy.deepcopy(feats_lens)
        return waveforms, feats_pad, feats_lens

    def get_fbank(self) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.fbanks, self.fbanks_lens
    def forward_lfr_cmvn(
            self,
            input: torch.Tensor,
            input_lengths: torch.Tensor,
            is_final: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            is_final: bool = False,
            cache: dict = {},
            **kwargs,
    ):
        batch_size = input.size(0)
        feats = []
        feats_lens = []
@ -383,7 +382,7 @@ class WavFrontendOnline(nn.Module):
            if self.lfr_m != 1 or self.lfr_n != 1:
                # update self.lfr_splice_cache in self.apply_lfr
                # mat, self.lfr_splice_cache[i], lfr_splice_frame_idx = self.apply_lfr(mat, self.lfr_m, self.lfr_n, self.lfr_splice_cache[i],
                mat, self.lfr_splice_cache[i], lfr_splice_frame_idx = self.apply_lfr(mat, self.lfr_m, self.lfr_n,
                mat, cache["lfr_splice_cache"][i], lfr_splice_frame_idx = self.apply_lfr(mat, self.lfr_m, self.lfr_n,
                                                                                         is_final)
            if self.cmvn_file is not None:
                mat = self.apply_cmvn(mat, self.cmvn)
@ -400,63 +399,68 @@ class WavFrontendOnline(nn.Module):
        return feats_pad, feats_lens, lfr_splice_frame_idxs

    def forward(
            self, input: torch.Tensor, input_lengths: torch.Tensor, is_final: bool = False, reset: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if reset:
            self.cache_reset()
            self, input: torch.Tensor, input_lengths: torch.Tensor, cache: dict = {}, **kwargs
    ):
        is_final = kwargs.get("is_final", False)
        reset = kwargs.get("reset", False)
        if len(cache) == 0 or reset:
            self.init_cache(cache)
        batch_size = input.shape[0]
        assert batch_size == 1, 'we support to extract feature online only when the batch size is equal to 1 now'
        waveforms, feats, feats_lengths = self.forward_fbank(input, input_lengths)  # input shape: B T D
        waveforms, feats, feats_lengths = self.forward_fbank(input, input_lengths, cache=cache)  # input shape: B T D
        if feats.shape[0]:
            # if self.reserve_waveforms is None and self.lfr_m > 1:
            #     self.reserve_waveforms = waveforms[:, :(self.lfr_m - 1) // 2 * self.frame_shift_sample_length]
            self.waveforms = waveforms if self.reserve_waveforms is None else torch.cat(
                (self.reserve_waveforms, waveforms), dim=1)
            if not self.lfr_splice_cache:  # initialize the splice cache
            cache["waveforms"] = torch.cat((cache["reserve_waveforms"], waveforms), dim=1)
            if not cache["lfr_splice_cache"]:  # initialize the splice cache
                for i in range(batch_size):
                    self.lfr_splice_cache.append(feats[i][0, :].unsqueeze(dim=0).repeat((self.lfr_m - 1) // 2, 1))
                    cache["lfr_splice_cache"].append(feats[i][0, :].unsqueeze(dim=0).repeat((self.lfr_m - 1) // 2, 1))
            # the number of input frames plus self.lfr_splice_cache[0].shape[0] needs to be at least self.lfr_m
            if feats_lengths[0] + self.lfr_splice_cache[0].shape[0] >= self.lfr_m:
                lfr_splice_cache_tensor = torch.stack(self.lfr_splice_cache)  # B T D
            if feats_lengths[0] + cache["lfr_splice_cache"][0].shape[0] >= self.lfr_m:
                lfr_splice_cache_tensor = torch.stack(cache["lfr_splice_cache"])  # B T D
                feats = torch.cat((lfr_splice_cache_tensor, feats), dim=1)
                feats_lengths += lfr_splice_cache_tensor[0].shape[0]
                frame_from_waveforms = int(
                    (self.waveforms.shape[1] - self.frame_sample_length) / self.frame_shift_sample_length + 1)
                minus_frame = (self.lfr_m - 1) // 2 if self.reserve_waveforms is None else 0
                feats, feats_lengths, lfr_splice_frame_idxs = self.forward_lfr_cmvn(feats, feats_lengths, is_final)
                    (cache["waveforms"].shape[1] - self.frame_sample_length) / self.frame_shift_sample_length + 1)
                minus_frame = (self.lfr_m - 1) // 2 if cache["reserve_waveforms"].numel() == 0 else 0
                feats, feats_lengths, lfr_splice_frame_idxs = self.forward_lfr_cmvn(feats, feats_lengths, is_final, cache=cache)
                if self.lfr_m == 1:
                    self.reserve_waveforms = None
                    cache["reserve_waveforms"] = torch.empty(0)
                else:
                    reserve_frame_idx = lfr_splice_frame_idxs[0] - minus_frame
                    # print('reserve_frame_idx: ' + str(reserve_frame_idx))
                    # print('frame_frame: ' + str(frame_from_waveforms))
                    self.reserve_waveforms = self.waveforms[:, reserve_frame_idx * self.frame_shift_sample_length:frame_from_waveforms * self.frame_shift_sample_length]
                    cache["reserve_waveforms"] = cache["waveforms"][:, reserve_frame_idx * self.frame_shift_sample_length:frame_from_waveforms * self.frame_shift_sample_length]
                    sample_length = (frame_from_waveforms - 1) * self.frame_shift_sample_length + self.frame_sample_length
                    self.waveforms = self.waveforms[:, :sample_length]
                    cache["waveforms"] = cache["waveforms"][:, :sample_length]
            else:
                # update the reserved waveforms and the LFR splice cache
                self.reserve_waveforms = self.waveforms[:,
                                         :-(self.frame_sample_length - self.frame_shift_sample_length)]
                cache["reserve_waveforms"] = cache["waveforms"][:, :-(self.frame_sample_length - self.frame_shift_sample_length)]
                for i in range(batch_size):
                    self.lfr_splice_cache[i] = torch.cat((self.lfr_splice_cache[i], feats[i]), dim=0)
                    cache["lfr_splice_cache"][i] = torch.cat((cache["lfr_splice_cache"][i], feats[i]), dim=0)
                return torch.empty(0), feats_lengths
        else:
            if is_final:
                self.waveforms = waveforms if self.reserve_waveforms is None else self.reserve_waveforms
                feats = torch.stack(self.lfr_splice_cache)
                cache["waveforms"] = waveforms if cache["reserve_waveforms"].numel() == 0 else cache["reserve_waveforms"]
                feats = torch.stack(cache["lfr_splice_cache"])
                feats_lengths = torch.zeros(batch_size, dtype=torch.int) + feats.shape[1]
                feats, feats_lengths, _ = self.forward_lfr_cmvn(feats, feats_lengths, is_final)
                feats, feats_lengths, _ = self.forward_lfr_cmvn(feats, feats_lengths, is_final, cache=cache)
        if is_final:
            self.cache_reset()
            self.init_cache(cache)
        return feats, feats_lengths

    def get_waveforms(self):
        return self.waveforms

    def cache_reset(self):
        self.reserve_waveforms = None
        self.input_cache = None
        self.lfr_splice_cache = []

    def init_cache(self, cache: dict = {}):
        cache["reserve_waveforms"] = torch.empty(0)
        cache["input_cache"] = torch.empty(0)
        cache["lfr_splice_cache"] = []
        cache["waveforms"] = None
        cache["fbanks"] = None
        cache["fbanks_lens"] = None
        return cache
class WavFrontendMel23(nn.Module):

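The net effect of the refactor above is that `WavFrontendOnline` no longer keeps per-stream state on `self`: the caller owns an explicit `cache` dict, so a single frontend instance can serve several concurrent streams. A minimal sketch of the intended calling pattern (the import path and constructor kwargs are assumptions; real values come from the model's config):
```python
import torch
from funasr.frontends.wav_frontend import WavFrontendOnline  # import path is an assumption

frontend = WavFrontendOnline(fs=16000, n_mels=80, lfr_m=7, lfr_n=6)  # hypothetical kwargs

cache_a, cache_b = {}, {}              # one cache dict per audio stream
chunk = torch.randn(1, 9600)           # 600 ms of 16 kHz audio; batch size must be 1
lens = torch.tensor([chunk.shape[1]])

# an empty cache is populated by init_cache on the first call
feats_a, _ = frontend(chunk, lens, cache=cache_a)                 # stream A
feats_b, _ = frontend(chunk, lens, cache=cache_b, is_final=True)  # stream B, flushed and reset
```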
View File

@ -1,69 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from espnet/espnet.
import numpy as np
import torch


def load_cmvn(cmvn_file):
    with open(cmvn_file, 'r', encoding='utf-8') as f:
        lines = f.readlines()
    means_list = []
    vars_list = []
    for i in range(len(lines)):
        line_item = lines[i].split()
        if line_item[0] == '<AddShift>':
            line_item = lines[i + 1].split()
            if line_item[0] == '<LearnRateCoef>':
                add_shift_line = line_item[3:(len(line_item) - 1)]
                means_list = list(add_shift_line)
                continue
        elif line_item[0] == '<Rescale>':
            line_item = lines[i + 1].split()
            if line_item[0] == '<LearnRateCoef>':
                rescale_line = line_item[3:(len(line_item) - 1)]
                vars_list = list(rescale_line)
                continue
    means = np.array(means_list).astype(np.float)
    vars = np.array(vars_list).astype(np.float)
    cmvn = np.array([means, vars])
    cmvn = torch.as_tensor(cmvn)
    return cmvn


def apply_cmvn(inputs, cmvn_file):  # noqa
    """
    Apply CMVN with mvn data
    """
    device = inputs.device
    dtype = inputs.dtype
    frame, dim = inputs.shape
    cmvn = load_cmvn(cmvn_file)
    means = np.tile(cmvn[0:1, :dim], (frame, 1))
    vars = np.tile(cmvn[1:2, :dim], (frame, 1))
    inputs += torch.from_numpy(means).type(dtype).to(device)
    inputs *= torch.from_numpy(vars).type(dtype).to(device)
    return inputs.type(torch.float32)


def apply_lfr(inputs, lfr_m, lfr_n):
    LFR_inputs = []
    T = inputs.shape[0]
    T_lfr = int(np.ceil(T / lfr_n))
    left_padding = inputs[0].repeat((lfr_m - 1) // 2, 1)
    inputs = torch.vstack((left_padding, inputs))
    T = T + (lfr_m - 1) // 2
    for i in range(T_lfr):
        if lfr_m <= T - i * lfr_n:
            LFR_inputs.append((inputs[i * lfr_n:i * lfr_n + lfr_m]).view(1, -1))
        else:  # process the last LFR frame
            num_padding = lfr_m - (T - i * lfr_n)
            frame = (inputs[i * lfr_n:]).view(-1)
            for _ in range(num_padding):
                frame = torch.hstack((frame, inputs[-1]))
            LFR_inputs.append(frame)
    LFR_outputs = torch.vstack(LFR_inputs)
    return LFR_outputs.type(torch.float32)

View File

@ -68,7 +68,7 @@ def load_bytes(input):
    array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
    return array
def extract_fbank(data, data_len=None, data_type: str = "sound", frontend=None):
def extract_fbank(data, data_len=None, data_type: str = "sound", frontend=None, **kwargs):
    # import pdb;
    # pdb.set_trace()
    if isinstance(data, np.ndarray):
@ -83,7 +83,7 @@ def extract_fbank(data, data_len = None, data_type: str="sound", frontend=None):
    elif isinstance(data, (list, tuple)):
        data_list, data_len = [], []
        for data_i in data:
            if isinstance(data, np.ndarray):
            if isinstance(data_i, np.ndarray):
                data_i = torch.from_numpy(data_i)
            data_list.append(data_i)
            data_len.append(data_i.shape[0])
@ -91,7 +91,7 @@ def extract_fbank(data, data_len = None, data_type: str="sound", frontend=None):
    # import pdb;
    # pdb.set_trace()
    # if data_type == "sound":
    data, data_len = frontend(data, data_len)
    data, data_len = frontend(data, data_len, **kwargs)
    if isinstance(data_len, (list, tuple)):
        data_len = torch.tensor([data_len])
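With these changes, keyword arguments such as `cache` and `is_final` flow from `extract_fbank` straight through to the frontend's `forward`, which is what the streaming frontend above relies on. A hedged sketch (import paths and constructor kwargs are assumptions):
```python
import numpy as np
from funasr.frontends.wav_frontend import WavFrontendOnline  # import path is an assumption
from funasr.utils.load_utils import extract_fbank            # import path is an assumption

frontend = WavFrontendOnline(fs=16000, n_mels=80, lfr_m=7, lfr_n=6)  # hypothetical kwargs
cache = {}
chunk = np.random.randn(9600).astype(np.float32)  # 600 ms at 16 kHz
# cache and is_final are forwarded to frontend.forward via **kwargs
feats, feats_lens = extract_fbank(chunk, data_type="sound", frontend=frontend,
                                  cache=cache, is_final=False)
```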