Merge pull request #189 from yuekaizhang/token_list

[Triton] Read token list from config.yaml
This commit is contained in:
zhifu gao 2023-03-06 17:21:40 +08:00 committed by GitHub
commit 659ad8f48b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 21 additions and 26 deletions

View File

@ -8,8 +8,8 @@ git clone https://www.modelscope.cn/damo/speech_paraformer-large_asr_nat-zh-cn-1
pretrained_model_dir=$(pwd)/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch
cp $pretrained_model_dir/tokens.txt ./model_repo_paraformer_large_offline/scoring/
cp $pretrained_model_dir/am.mvn ./model_repo_paraformer_large_offline/feature_extractor/
cp $pretrained_model_dir/config.yaml ./model_repo_paraformer_large_offline/feature_extractor/
# Refer here to get model.onnx (https://github.com/alibaba-damo-academy/FunASR/blob/main/funasr/export/README.md)
cp <exported_onnx_dir>/model.onnx ./model_repo_paraformer_large_offline/encoder/1/
@ -33,10 +33,9 @@ model_repo_paraformer_large_offline/
`-- scoring
|-- 1
| `-- model.py
|-- config.pbtxt
`-- tokens.txt
`-- config.pbtxt
8 directories, 10 files
8 directories, 9 files
```
2. Follow below instructions to launch triton server

View File

@ -229,22 +229,24 @@ class TritonPythonModel:
if key == "config_path":
with open(str(value), 'rb') as f:
config = yaml.load(f, Loader=yaml.Loader)
if key == "cmvn_path":
cmvn_path = str(value)
opts = kaldifeat.FbankOptions()
opts.frame_opts.dither = 1.0 # TODO: 0.0 or 1.0
opts.frame_opts.window_type = config['WavFrontend']['frontend_conf']['window']
opts.mel_opts.num_bins = int(config['WavFrontend']['frontend_conf']['n_mels'])
opts.frame_opts.frame_shift_ms = float(config['WavFrontend']['frontend_conf']['frame_shift'])
opts.frame_opts.frame_length_ms = float(config['WavFrontend']['frontend_conf']['frame_length'])
opts.frame_opts.samp_freq = int(config['WavFrontend']['frontend_conf']['fs'])
opts.frame_opts.window_type = config['frontend_conf']['window']
opts.mel_opts.num_bins = int(config['frontend_conf']['n_mels'])
opts.frame_opts.frame_shift_ms = float(config['frontend_conf']['frame_shift'])
opts.frame_opts.frame_length_ms = float(config['frontend_conf']['frame_length'])
opts.frame_opts.samp_freq = int(config['frontend_conf']['fs'])
opts.device = torch.device(self.device)
self.opts = opts
self.feature_extractor = Fbank(self.opts)
self.feature_size = opts.mel_opts.num_bins
self.frontend = WavFrontend(
cmvn_file=config['WavFrontend']['cmvn_file'],
**config['WavFrontend']['frontend_conf'])
cmvn_file=cmvn_path,
**config['frontend_conf'])
def extract_feat(self,
waveform_list: List[np.ndarray]

View File

@ -33,6 +33,10 @@ parameters [
key: "sample_rate"
value: { string_value: "16000"}
},
{
key: "cmvn_path"
value: { string_value: "./model_repo_paraformer_large_offline/feature_extractor/am.mvn"}
},
{
key: "config_path"
value: { string_value: "./model_repo_paraformer_large_offline/feature_extractor/config.yaml"}

View File

@ -1,11 +0,0 @@
WavFrontend:
cmvn_file: ./model_repo_paraformer_large_offline/feature_extractor/am.mvn
frontend_conf:
fs: 16000
window: hamming
n_mels: 80
frame_length: 25
frame_shift: 10
lfr_m: 7
lfr_n: 6
filter_length_max: -.inf

View File

@ -21,6 +21,7 @@ from torch.utils.dlpack import from_dlpack
import json
import os
import yaml
class TritonPythonModel:
"""Your Python model must use the same class name. Every Python model
@ -73,9 +74,9 @@ class TritonPythonModel:
"""
load lang_char.txt
"""
with open(str(vocab_file), 'r') as f:
token_list = [line.strip() for line in f]
return token_list
with open(str(vocab_file), 'rb') as f:
config = yaml.load(f, Loader=yaml.Loader)
return config['token_list']
def execute(self, requests):
"""`execute` must be implemented in every Python model. `execute`

View File

@ -23,7 +23,7 @@ parameters [
},
{
key: "vocabulary",
value: { string_value: "./model_repo_paraformer_large_offline/scoring/tokens.txt"}
value: { string_value: "./model_repo_paraformer_large_offline/feature_extractor/config.yaml"}
},
{
key: "lm_path"