export model

This commit is contained in:
游雁 2023-02-13 17:43:01 +08:00
parent de264be093
commit 865ae89f0a
20 changed files with 427 additions and 53 deletions

123
fbank.py Normal file
View File

@ -0,0 +1,123 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from espnet/espnet.
from typing import Tuple
import numpy as np
import torch
import torchaudio.compliance.kaldi as kaldi
from funasr.models.frontend.abs_frontend import AbsFrontend
from typeguard import check_argument_types
from torch.nn.utils.rnn import pad_sequence
import kaldi_native_fbank as knf
class WavFrontend(AbsFrontend):
"""Conventional frontend structure for ASR.
"""
def __init__(
self,
cmvn_file: str = None,
fs: int = 16000,
window: str = 'hamming',
n_mels: int = 80,
frame_length: int = 25,
frame_shift: int = 10,
filter_length_min: int = -1,
filter_length_max: int = -1,
lfr_m: int = 1,
lfr_n: int = 1,
dither: float = 1.0,
snip_edges: bool = True,
upsacle_samples: bool = True,
):
assert check_argument_types()
super().__init__()
self.fs = fs
self.window = window
self.n_mels = n_mels
self.frame_length = frame_length
self.frame_shift = frame_shift
self.filter_length_min = filter_length_min
self.filter_length_max = filter_length_max
self.lfr_m = lfr_m
self.lfr_n = lfr_n
self.cmvn_file = cmvn_file
self.dither = dither
self.snip_edges = snip_edges
self.upsacle_samples = upsacle_samples
def output_size(self) -> int:
return self.n_mels * self.lfr_m
def forward(
self,
input: torch.Tensor,
input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size = input.size(0)
feats = []
feats_lens = []
for i in range(batch_size):
waveform_length = input_lengths[i]
waveform = input[i][:waveform_length]
waveform = waveform * (1 << 15)
waveform = waveform.unsqueeze(0)
mat = kaldi.fbank(waveform,
num_mel_bins=self.n_mels,
frame_length=self.frame_length,
frame_shift=self.frame_shift,
dither=self.dither,
energy_floor=0.0,
window_type=self.window,
sample_frequency=self.fs)
feat_length = mat.size(0)
feats.append(mat)
feats_lens.append(feat_length)
feats_lens = torch.as_tensor(feats_lens)
feats_pad = pad_sequence(feats,
batch_first=True,
padding_value=0.0)
return feats_pad, feats_lens
import kaldi_native_fbank as knf
def fbank_knf(waveform):
# sampling_rate = 16000
# samples = torch.randn(16000 * 10)
opts = knf.FbankOptions()
opts.frame_opts.samp_freq = 16000
opts.frame_opts.dither = 0.0
opts.frame_opts.window_type = "hamming"
opts.frame_opts.frame_shift_ms = 10.0
opts.frame_opts.frame_length_ms = 25.0
opts.mel_opts.num_bins = 80
opts.energy_floor = 1
opts.frame_opts.snip_edges = True
opts.mel_opts.debug_mel = False
fbank = knf.OnlineFbank(opts)
waveform = waveform * (1 << 15)
fbank.accept_waveform(opts.frame_opts.samp_freq, waveform.tolist())
frames = fbank.num_frames_ready
mat = np.empty([frames, opts.mel_opts.num_bins])
for i in range(frames):
mat[i, :] = fbank.get_frame(i)
return mat
if __name__ == '__main__':
import librosa
path = "/home/zhifu.gzf/.cache/modelscope/hub/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav"
waveform, fs = librosa.load(path, sr=None)
fbank = fbank_knf(waveform)
frontend = WavFrontend(dither=0.0)
waveform_tensor = torch.from_numpy(waveform)[None, :]
fbank_torch, _ = frontend.forward(waveform_tensor, [waveform_tensor.size(1)])
fbank_torch = fbank_torch.cpu().numpy()[0, :, :]
diff = fbank - fbank_torch
diff_max = diff.max()
diff_sum = diff.abs().sum()
pass

View File

@ -171,10 +171,7 @@ class WavFrontend(AbsFrontend):
window_type=self.window,
sample_frequency=self.fs)
# if self.lfr_m != 1 or self.lfr_n != 1:
# mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
# if self.cmvn_file is not None:
# mat = apply_cmvn(mat, self.cmvn_file)
feat_length = mat.size(0)
feats.append(mat)
feats_lens.append(feat_length)

View File

View File

View File

@ -29,12 +29,6 @@
│   └── utils.py
├── README.md
├── requirements.txt
├── resources
│   ├── config.yaml
│   └── models
│   ├── am.mvn
│   ├── model.onnx # Put it here.
│   └── token_list.pkl
├── test_onnx.py
├── tests
│   ├── __pycache__
@ -48,15 +42,15 @@
- Output: `List[str]`: recognition result.
- Example:
```python
from rapid_paraformer import RapidParaformer
from paraformer_onnx import Paraformer
config_path = 'resources/config.yaml'
paraformer = RapidParaformer(config_path)
model = Paraformer(config_path)
wav_path = ['test_wavs/0478_00017.wav']
wav_path = ['example/asr_example.wav']
result = paraformer(wav_path)
result = model(wav_path)
print(result)
```

View File

@ -1,6 +1,7 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
import os.path
import traceback
from pathlib import Path
from typing import List, Union, Tuple
@ -11,25 +12,33 @@ import numpy as np
from .utils import (CharTokenizer, Hypothesis, ONNXRuntimeError,
OrtInferSession, TokenIDConverter, WavFrontend, get_logger,
read_yaml)
from .postprocess_utils import sentence_postprocess
logging = get_logger()
class RapidParaformer():
def __init__(self, config_path: Union[str, Path]) -> None:
if not Path(config_path).exists():
raise FileNotFoundError(f'{config_path} does not exist.')
class Paraformer():
def __init__(self, model_dir: Union[str, Path]=None,
batch_size: int = 1,
device_id: Union[str, int]="-1",
):
if not Path(model_dir).exists():
raise FileNotFoundError(f'{model_dir} does not exist.')
config = read_yaml(config_path)
model_file = os.path.join(model_dir, 'model.onnx')
config_file = os.path.join(model_dir, 'config.yaml')
cmvn_file = os.path.join(model_dir, 'am.mvn')
config = read_yaml(config_file)
self.converter = TokenIDConverter(**config['TokenIDConverter'])
self.tokenizer = CharTokenizer(**config['CharTokenizer'])
self.converter = TokenIDConverter(config['token_list'])
self.tokenizer = CharTokenizer()
self.frontend = WavFrontend(
cmvn_file=config['WavFrontend']['cmvn_file'],
**config['WavFrontend']['frontend_conf']
cmvn_file=cmvn_file,
**config['frontend_conf']
)
self.ort_infer = OrtInferSession(config['Model'])
self.batch_size = config['Model']['batch_size']
self.ort_infer = OrtInferSession(model_file, device_id)
self.batch_size = batch_size
def __call__(self, wav_content: Union[str, np.ndarray, List[str]]) -> List:
waveform_list = self.load_data(wav_content)
@ -124,16 +133,19 @@ class RapidParaformer():
# Change integer-ids to tokens
token = self.converter.ids2tokens(token_int)
text = self.tokenizer.tokens2text(token)
token = token[:valid_token_num-1]
texts = sentence_postprocess(token)
text = texts[0]
# text = self.tokenizer.tokens2text(token)
return text[:valid_token_num-1]
if __name__ == '__main__':
project_dir = Path(__file__).resolve().parent.parent
cfg_path = project_dir / 'resources' / 'config.yaml'
paraformer = RapidParaformer(cfg_path)
model_dir = "/home/zhifu.gzf/.cache/modelscope/hub/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
model = Paraformer(model_dir)
wav_file = os.path.join(model_dir, 'example/asr_example.wav')
result = model(wav_file)
print(result)
wav_file = '0478_00017.wav'
for i in range(1000):
result = paraformer(wav_file)
print(result)

View File

@ -0,0 +1,240 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import string
import logging
from typing import Any, List, Union
def isChinese(ch: str):
if '\u4e00' <= ch <= '\u9fff' or '\u0030' <= ch <= '\u0039':
return True
return False
def isAllChinese(word: Union[List[Any], str]):
word_lists = []
for i in word:
cur = i.replace(' ', '')
cur = cur.replace('</s>', '')
cur = cur.replace('<s>', '')
word_lists.append(cur)
if len(word_lists) == 0:
return False
for ch in word_lists:
if isChinese(ch) is False:
return False
return True
def isAllAlpha(word: Union[List[Any], str]):
word_lists = []
for i in word:
cur = i.replace(' ', '')
cur = cur.replace('</s>', '')
cur = cur.replace('<s>', '')
word_lists.append(cur)
if len(word_lists) == 0:
return False
for ch in word_lists:
if ch.isalpha() is False and ch != "'":
return False
elif ch.isalpha() is True and isChinese(ch) is True:
return False
return True
# def abbr_dispose(words: List[Any]) -> List[Any]:
def abbr_dispose(words: List[Any], time_stamp: List[List] = None) -> List[Any]:
words_size = len(words)
word_lists = []
abbr_begin = []
abbr_end = []
last_num = -1
ts_lists = []
ts_nums = []
ts_index = 0
for num in range(words_size):
if num <= last_num:
continue
if len(words[num]) == 1 and words[num].encode('utf-8').isalpha():
if num + 1 < words_size and words[
num + 1] == ' ' and num + 2 < words_size and len(
words[num +
2]) == 1 and words[num +
2].encode('utf-8').isalpha():
# found the begin of abbr
abbr_begin.append(num)
num += 2
abbr_end.append(num)
# to find the end of abbr
while True:
num += 1
if num < words_size and words[num] == ' ':
num += 1
if num < words_size and len(
words[num]) == 1 and words[num].encode(
'utf-8').isalpha():
abbr_end.pop()
abbr_end.append(num)
last_num = num
else:
break
else:
break
for num in range(words_size):
if words[num] == ' ':
ts_nums.append(ts_index)
else:
ts_nums.append(ts_index)
ts_index += 1
last_num = -1
for num in range(words_size):
if num <= last_num:
continue
if num in abbr_begin:
if time_stamp is not None:
begin = time_stamp[ts_nums[num]][0]
word_lists.append(words[num].upper())
num += 1
while num < words_size:
if num in abbr_end:
word_lists.append(words[num].upper())
last_num = num
break
else:
if words[num].encode('utf-8').isalpha():
word_lists.append(words[num].upper())
num += 1
if time_stamp is not None:
end = time_stamp[ts_nums[num]][1]
ts_lists.append([begin, end])
else:
word_lists.append(words[num])
if time_stamp is not None and words[num] != ' ':
begin = time_stamp[ts_nums[num]][0]
end = time_stamp[ts_nums[num]][1]
ts_lists.append([begin, end])
begin = end
if time_stamp is not None:
return word_lists, ts_lists
else:
return word_lists
def sentence_postprocess(words: List[Any], time_stamp: List[List] = None):
middle_lists = []
word_lists = []
word_item = ''
ts_lists = []
# wash words lists
for i in words:
word = ''
if isinstance(i, str):
word = i
else:
word = i.decode('utf-8')
if word in ['<s>', '</s>', '<unk>']:
continue
else:
middle_lists.append(word)
# all chinese characters
if isAllChinese(middle_lists):
for i, ch in enumerate(middle_lists):
word_lists.append(ch.replace(' ', ''))
if time_stamp is not None:
ts_lists = time_stamp
# all alpha characters
elif isAllAlpha(middle_lists):
ts_flag = True
for i, ch in enumerate(middle_lists):
if ts_flag and time_stamp is not None:
begin = time_stamp[i][0]
end = time_stamp[i][1]
word = ''
if '@@' in ch:
word = ch.replace('@@', '')
word_item += word
if time_stamp is not None:
ts_flag = False
end = time_stamp[i][1]
else:
word_item += ch
word_lists.append(word_item)
word_lists.append(' ')
word_item = ''
if time_stamp is not None:
ts_flag = True
end = time_stamp[i][1]
ts_lists.append([begin, end])
begin = end
# mix characters
else:
alpha_blank = False
ts_flag = True
begin = -1
end = -1
for i, ch in enumerate(middle_lists):
if ts_flag and time_stamp is not None:
begin = time_stamp[i][0]
end = time_stamp[i][1]
word = ''
if isAllChinese(ch):
if alpha_blank is True:
word_lists.pop()
word_lists.append(ch)
alpha_blank = False
if time_stamp is not None:
ts_flag = True
ts_lists.append([begin, end])
begin = end
elif '@@' in ch:
word = ch.replace('@@', '')
word_item += word
alpha_blank = False
if time_stamp is not None:
ts_flag = False
end = time_stamp[i][1]
elif isAllAlpha(ch):
word_item += ch
word_lists.append(word_item)
word_lists.append(' ')
word_item = ''
alpha_blank = True
if time_stamp is not None:
ts_flag = True
end = time_stamp[i][1]
ts_lists.append([begin, end])
begin = end
else:
raise ValueError('invalid character: {}'.format(ch))
if time_stamp is not None:
word_lists, ts_lists = abbr_dispose(word_lists, ts_lists)
real_word_lists = []
for ch in word_lists:
if ch != ' ':
real_word_lists.append(ch)
sentence = ' '.join(real_word_lists).strip()
return sentence, ts_lists, real_word_lists
else:
word_lists = abbr_dispose(word_lists)
real_word_lists = []
for ch in word_lists:
if ch != ' ':
real_word_lists.append(ch)
sentence = ''.join(word_lists).strip()
return sentence, real_word_lists

View File

@ -14,6 +14,7 @@ from onnxruntime import (GraphOptimizationLevel, InferenceSession,
from typeguard import check_argument_types
from .kaldifeat import compute_fbank_feats
import warnings
root_dir = Path(__file__).resolve().parent
@ -21,24 +22,25 @@ logger_initialized = {}
class TokenIDConverter():
def __init__(self, token_path: Union[Path, str],
def __init__(self, token_list: Union[Path, str],
unk_symbol: str = "<unk>",):
check_argument_types()
self.token_list = self.load_token(token_path)
self.unk_symbol = unk_symbol
# self.token_list = self.load_token(token_path)
self.token_list = token_list
self.unk_symbol = token_list[-1]
@staticmethod
def load_token(file_path: Union[Path, str]) -> List:
if not Path(file_path).exists():
raise TokenIDConverterError(f'The {file_path} does not exist.')
with open(str(file_path), 'rb') as f:
token_list = pickle.load(f)
if len(token_list) != len(set(token_list)):
raise TokenIDConverterError('The Token exists duplicated symbol.')
return token_list
# @staticmethod
# def load_token(file_path: Union[Path, str]) -> List:
# if not Path(file_path).exists():
# raise TokenIDConverterError(f'The {file_path} does not exist.')
#
# with open(str(file_path), 'rb') as f:
# token_list = pickle.load(f)
#
# if len(token_list) != len(set(token_list)):
# raise TokenIDConverterError('The Token exists duplicated symbol.')
# return token_list
def get_num_vocabulary_size(self) -> int:
return len(self.token_list)
@ -268,31 +270,36 @@ class ONNXRuntimeError(Exception):
class OrtInferSession():
def __init__(self, config):
def __init__(self, model_file, device_id=-1):
sess_opt = SessionOptions()
sess_opt.log_severity_level = 4
sess_opt.enable_cpu_mem_arena = False
sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
cuda_ep = 'CUDAExecutionProvider'
cuda_provider_options = {
"device_id": device_id,
"arena_extend_strategy": "kNextPowerOfTwo",
"cudnn_conv_algo_search": "EXHAUSTIVE",
"do_copy_in_default_stream": "true",
}
cpu_ep = 'CPUExecutionProvider'
cpu_provider_options = {
"arena_extend_strategy": "kSameAsRequested",
}
EP_list = []
if config['use_cuda'] and get_device() == 'GPU' \
if device_id != -1 and get_device() == 'GPU' \
and cuda_ep in get_available_providers():
EP_list = [(cuda_ep, config[cuda_ep])]
EP_list = [(cuda_ep, cuda_provider_options)]
EP_list.append((cpu_ep, cpu_provider_options))
config['model_path'] = config['model_path']
self._verify_model(config['model_path'])
self.session = InferenceSession(config['model_path'],
self._verify_model(model_file)
self.session = InferenceSession(model_file,
sess_options=sess_opt,
providers=EP_list)
if config['use_cuda'] and cuda_ep not in self.session.get_providers():
if device_id != -1 and cuda_ep not in self.session.get_providers():
warnings.warn(f'{cuda_ep} is not avaiable for current env, the inference part is automatically shifted to be executed under {cpu_ep}.\n'
'Please ensure the installed onnxruntime-gpu version matches your cuda and cudnn version, '
'you can check their relations from the offical web site: '

View File

@ -18,6 +18,7 @@ WavFrontend:
lfr_m: 7
lfr_n: 6
filter_length_max: -.inf
dither: 0.0
Model:
model_path: resources/models/model.onnx