update token list

This commit is contained in:
Yuekai Zhang 2023-03-06 16:48:02 +08:00
parent 9b9c948b3c
commit 80e6c258cf
6 changed files with 19 additions and 22 deletions

View File

@ -10,6 +10,7 @@ pretrained_model_dir=$(pwd)/speech_paraformer-large_asr_nat-zh-cn-16k-common-voc
cp $pretrained_model_dir/tokens.txt ./model_repo_paraformer_large_offline/scoring/
cp $pretrained_model_dir/am.mvn ./model_repo_paraformer_large_offline/feature_extractor/
cp $pretrained_model_dir/config.yaml ./model_repo_paraformer_large_offline/feature_extractor/
# Refer here to get model.onnx (https://github.com/alibaba-damo-academy/FunASR/blob/main/funasr/export/README.md)
cp <exported_onnx_dir>/model.onnx ./model_repo_paraformer_large_offline/encoder/1/

View File

@ -229,22 +229,24 @@ class TritonPythonModel:
if key == "config_path":
with open(str(value), 'rb') as f:
config = yaml.load(f, Loader=yaml.Loader)
if key == "cmvn_path":
cmvn_path = str(value)
opts = kaldifeat.FbankOptions()
opts.frame_opts.dither = 1.0 # TODO: 0.0 or 1.0
opts.frame_opts.window_type = config['WavFrontend']['frontend_conf']['window']
opts.mel_opts.num_bins = int(config['WavFrontend']['frontend_conf']['n_mels'])
opts.frame_opts.frame_shift_ms = float(config['WavFrontend']['frontend_conf']['frame_shift'])
opts.frame_opts.frame_length_ms = float(config['WavFrontend']['frontend_conf']['frame_length'])
opts.frame_opts.samp_freq = int(config['WavFrontend']['frontend_conf']['fs'])
opts.frame_opts.window_type = config['frontend_conf']['window']
opts.mel_opts.num_bins = int(config['frontend_conf']['n_mels'])
opts.frame_opts.frame_shift_ms = float(config['frontend_conf']['frame_shift'])
opts.frame_opts.frame_length_ms = float(config['frontend_conf']['frame_length'])
opts.frame_opts.samp_freq = int(config['frontend_conf']['fs'])
opts.device = torch.device(self.device)
self.opts = opts
self.feature_extractor = Fbank(self.opts)
self.feature_size = opts.mel_opts.num_bins
self.frontend = WavFrontend(
cmvn_file=config['WavFrontend']['cmvn_file'],
**config['WavFrontend']['frontend_conf'])
cmvn_file=cmvn_path,
**config['frontend_conf'])
def extract_feat(self,
waveform_list: List[np.ndarray]

View File

@ -33,6 +33,10 @@ parameters [
key: "sample_rate"
value: { string_value: "16000"}
},
{
key: "cmvn_path"
value: { string_value: "./model_repo_paraformer_large_offline/feature_extractor/am.mvn"}
},
{
key: "config_path"
value: { string_value: "./model_repo_paraformer_large_offline/feature_extractor/config.yaml"}

View File

@ -1,11 +0,0 @@
WavFrontend:
cmvn_file: ./model_repo_paraformer_large_offline/feature_extractor/am.mvn
frontend_conf:
fs: 16000
window: hamming
n_mels: 80
frame_length: 25
frame_shift: 10
lfr_m: 7
lfr_n: 6
filter_length_max: -.inf

View File

@ -21,6 +21,7 @@ from torch.utils.dlpack import from_dlpack
import json
import os
import yaml
class TritonPythonModel:
"""Your Python model must use the same class name. Every Python model
@ -73,9 +74,9 @@ class TritonPythonModel:
"""
load lang_char.txt
"""
with open(str(vocab_file), 'r') as f:
token_list = [line.strip() for line in f]
return token_list
with open(str(vocab_file), 'rb') as f:
config = yaml.load(f, Loader=yaml.Loader)
return config['token_list']
def execute(self, requests):
"""`execute` must be implemented in every Python model. `execute`

View File

@ -23,7 +23,7 @@ parameters [
},
{
key: "vocabulary",
value: { string_value: "./model_repo_paraformer_large_offline/scoring/tokens.txt"}
value: { string_value: "./model_repo_paraformer_large_offline/feature_extractor/config.yaml"}
},
{
key: "lm_path"