Merge branch 'dev_infer' of https://github.com/alibaba/FunASR into dev_infer

This commit is contained in:
嘉渊 2023-05-17 18:47:53 +08:00
commit 334dec5d18
22 changed files with 343 additions and 338 deletions

View File

@ -68,10 +68,12 @@ Overview
./runtime/onnxruntime_python.md
./runtime/onnxruntime_cpp.md
./runtime/libtorch_python.md
./runtime/grpc_python.md
./runtime/grpc_cpp.md
./runtime/html5.md
./runtime/websocket_python.md
./runtime/websocket_cpp.md
./runtime/grpc_python.md
./runtime/grpc_cpp.md
.. toctree::
:maxdepth: 1

1
docs/runtime/html5.md Symbolic link
View File

@ -0,0 +1 @@
../../funasr/runtime/html5/readme.md

View File

@ -9,7 +9,7 @@ logger.setLevel(logging.CRITICAL)
inference_pipeline = pipeline(
task=Tasks.punctuation,
model='damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727',
output_dir="./tmp/"
model_revision = 'v1.0.2'
)
##################text二进制数据#####################

View File

@ -762,23 +762,6 @@ class Speech2TextParaformerOnline:
feats_len = speech_lengths
if feats.shape[1] != 0:
if cache_en["is_final"]:
if feats.shape[1] + cache_en["chunk_size"][2] < cache_en["chunk_size"][1]:
cache_en["last_chunk"] = True
else:
# first chunk
feats_chunk1 = feats[:, :cache_en["chunk_size"][1], :]
feats_len = torch.tensor([feats_chunk1.shape[1]])
results_chunk1 = self.infer(feats_chunk1, feats_len, cache)
# last chunk
cache_en["last_chunk"] = True
feats_chunk2 = feats[:, -(feats.shape[1] + cache_en["chunk_size"][2] - cache_en["chunk_size"][1]):, :]
feats_len = torch.tensor([feats_chunk2.shape[1]])
results_chunk2 = self.infer(feats_chunk2, feats_len, cache)
return [" ".join(results_chunk1 + results_chunk2)]
results = self.infer(feats, feats_len, cache)
return results

View File

@ -36,6 +36,8 @@ def main(args=None, cmd=None):
from funasr.tasks.asr import ASRTaskParaformer as ASRTask
if args.mode == "uniasr":
from funasr.tasks.asr import ASRTaskUniASR as ASRTask
if args.mode == "rnnt":
from funasr.tasks.asr import ASRTransducerTask as ASRTask
ASRTask.main(args=args, cmd=cmd)

View File

@ -19,12 +19,15 @@ from funasr.models.decoder.transformer_decoder import (
)
from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
from funasr.models.decoder.transformer_decoder import TransformerDecoder
from funasr.models.decoder.rnnt_decoder import RNNTDecoder
from funasr.models.joint_net.joint_network import JointNetwork
from funasr.models.e2e_asr import ASRModel
from funasr.models.e2e_asr_mfcca import MFCCA
from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer, ContextualParaformer
from funasr.models.e2e_tp import TimestampPredictor
from funasr.models.e2e_uni_asr import UniASR
from funasr.models.encoder.conformer_encoder import ConformerEncoder
from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel
from funasr.models.encoder.conformer_encoder import ConformerEncoder, ConformerChunkEncoder
from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
from funasr.models.encoder.mfcca_encoder import MFCCAEncoder
from funasr.models.encoder.rnn_encoder import RNNEncoder
@ -97,6 +100,7 @@ encoder_choices = ClassChoices(
sanm_chunk_opt=SANMEncoderChunkOpt,
data2vec_encoder=Data2VecEncoder,
mfcca_enc=MFCCAEncoder,
chunk_conformer=ConformerChunkEncoder,
),
default="rnn",
)
@ -171,6 +175,23 @@ stride_conv_choices = ClassChoices(
default="stride_conv1d",
optional=True,
)
rnnt_decoder_choices = ClassChoices(
name="rnnt_decoder",
classes=dict(
rnnt=RNNTDecoder,
),
default="rnnt",
optional=True,
)
joint_network_choices = ClassChoices(
name="joint_network",
classes=dict(
joint_network=JointNetwork,
),
default="joint_network",
optional=True,
)
class_choices_list = [
# --frontend and --frontend_conf
frontend_choices,
@ -194,6 +215,10 @@ class_choices_list = [
predictor_choices2,
# --stride_conv and --stride_conv_conf
stride_conv_choices,
# --rnnt_decoder and --rnnt_decoder_conf
rnnt_decoder_choices,
# --joint_network and --joint_network_conf
joint_network_choices,
]
@ -342,6 +367,63 @@ def build_asr_model(args):
token_list=token_list,
**args.model_conf,
)
elif args.model == "rnnt":
# 5. Decoder
encoder_output_size = encoder.output_size()
rnnt_decoder_class = rnnt_decoder_choices.get_class(args.rnnt_decoder)
decoder = rnnt_decoder_class(
vocab_size,
**args.rnnt_decoder_conf,
)
decoder_output_size = decoder.output_size
if getattr(args, "decoder", None) is not None:
att_decoder_class = decoder_choices.get_class(args.decoder)
att_decoder = att_decoder_class(
vocab_size=vocab_size,
encoder_output_size=encoder_output_size,
**args.decoder_conf,
)
else:
att_decoder = None
# 6. Joint Network
joint_network = JointNetwork(
vocab_size,
encoder_output_size,
decoder_output_size,
**args.joint_network_conf,
)
# 7. Build model
if hasattr(encoder, 'unified_model_training') and encoder.unified_model_training:
model = UnifiedTransducerModel(
vocab_size=vocab_size,
token_list=token_list,
frontend=frontend,
specaug=specaug,
normalize=normalize,
encoder=encoder,
decoder=decoder,
att_decoder=att_decoder,
joint_network=joint_network,
**args.model_conf,
)
else:
model = TransducerModel(
vocab_size=vocab_size,
token_list=token_list,
frontend=frontend,
specaug=specaug,
normalize=normalize,
encoder=encoder,
decoder=decoder,
att_decoder=att_decoder,
joint_network=joint_network,
**args.model_conf,
)
else:
raise NotImplementedError("Not supported model: {}".format(args.model))
@ -349,4 +431,4 @@ def build_asr_model(args):
if args.init is not None:
initialize(model, args.init)
return model
return model

View File

@ -12,7 +12,7 @@ if __name__ == '__main__':
return {'inputs': np.ones((1, text_length), dtype=np.int64),
'text_lengths': np.array([text_length,], dtype=np.int32),
'vad_masks': np.ones((1, 1, text_length, text_length), dtype=np.float32),
'sub_masks': np.tril(np.ones((text_length, text_length), dtype=np.float32))[None, None, :, :].astype(np.float32)
'sub_masks': np.ones((1, 1, text_length, text_length), dtype=np.float32),
}
def _run(feed_dict):

View File

@ -1078,7 +1078,7 @@ class ConformerChunkEncoder(AbsEncoder):
limit_size,
)
mask = make_source_mask(x_len)
mask = make_source_mask(x_len).to(x.device)
if self.unified_model_training:
chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item()

View File

@ -355,18 +355,9 @@ class SANMEncoder(AbsEncoder):
def _add_overlap_chunk(self, feats: np.ndarray, cache: dict = {}):
if len(cache) == 0:
return feats
# process last chunk
cache["feats"] = to_device(cache["feats"], device=feats.device)
overlap_feats = torch.cat((cache["feats"], feats), dim=1)
if cache["is_final"]:
cache["feats"] = overlap_feats[:, -cache["chunk_size"][0]:, :]
if not cache["last_chunk"]:
padding_length = sum(cache["chunk_size"]) - overlap_feats.shape[1]
overlap_feats = overlap_feats.transpose(1, 2)
overlap_feats = F.pad(overlap_feats, (0, padding_length))
overlap_feats = overlap_feats.transpose(1, 2)
else:
cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :]
cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :]
return overlap_feats
def forward_chunk(self,

View File

@ -221,13 +221,14 @@ class CifPredictorV2(nn.Module):
if cache is not None and "chunk_size" in cache:
alphas[:, :cache["chunk_size"][0]] = 0.0
alphas[:, sum(cache["chunk_size"][:2]):] = 0.0
if "is_final" in cache and not cache["is_final"]:
alphas[:, sum(cache["chunk_size"][:2]):] = 0.0
if cache is not None and "cif_alphas" in cache and "cif_hidden" in cache:
cache["cif_hidden"] = to_device(cache["cif_hidden"], device=hidden.device)
cache["cif_alphas"] = to_device(cache["cif_alphas"], device=alphas.device)
hidden = torch.cat((cache["cif_hidden"], hidden), dim=1)
alphas = torch.cat((cache["cif_alphas"], alphas), dim=1)
if cache is not None and "last_chunk" in cache and cache["last_chunk"]:
if cache is not None and "is_final" in cache and cache["is_final"]:
tail_hidden = torch.zeros((batch_size, 1, hidden_size), device=hidden.device)
tail_alphas = torch.tensor([[self.tail_threshold]], device=alphas.device)
tail_alphas = torch.tile(tail_alphas, (batch_size, 1))

View File

@ -9,70 +9,70 @@ pyOpenSSL
```
### javascript
[html5录音](https://github.com/xiangyuecn/Recorder)
[html5 recorder.js](https://github.com/xiangyuecn/Recorder)
```shell
Recorder
```
### demo页面如下
![img](https://github.com/alibaba-damo-academy/FunASR/blob/for-html5-demo/funasr/runtime/html5/demo.gif)
### demo
![img](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/runtime/html5/demo.gif)
## 两种ws_server_online连接模式
### 1)直接连接模式浏览器https麦克风 --> html5 demo服务 --> js wss接口 --> wss asr online srv(证书生成请往后看)
## wss or ws protocol for ws_server_online
1) wss: browser microphone data --> html5 demo server --> js wss api --> wss asr online srv #for certificate generation just look back
### 2)nginx中转浏览器https麦克风 --> html5 demo服务 --> js wss接口 --> nginx服务 --> ws asr online srv
2) ws: browser microphone data --> html5 demo server --> js wss api --> nginx wss server --> ws asr online srv
## 1.html5 demo服务启动
### 启动html5服务需要ssl证书(自己生成请往后看)
## 1.html5 demo start
### ssl certificate is required
```shell
usage: h5Server.py [-h] [--host HOST] [--port PORT] [--certfile CERTFILE]
[--keyfile KEYFILE]
python h5Server.py --port 1337
```
## 2.启动ws or wss asr online srv
[具体请看online asr](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/runtime/python/websocket)
online asr提供两种ws和wss模式wss模式可以直接启动无需nginx中转。否则需要通过nginx将wss转发到该online asr的ws端口上
### wss方式
## 2.asr online srv start
[detail for online asr](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/runtime/python/websocket)
Online asr provides wss or ws way. if started in ws way, nginx is required for relay.
### wss way, ssl certificate is required
```shell
python ws_server_online.py --certfile server.crt --keyfile server.key --port 5921
```
### ws方式
### ws way
```shell
python ws_server_online.py --port 5921
```
## 3.修改wsconnecter.js里asr接口地址
wsconnecter.js里配置online asr服务地址路径这里配置的是wss端口
## 3.modify asr address in wsconnecter.js according to your environment
asr address in wsconnecter.js must be wss, just like
var Uri = "wss://xxx:xxx/"
## 4.浏览器打开地址测试
https://127.0.0.1:1337/static/index.html
## 4.open browser to access html5 demo
https://youraddress:port/static/index.html
## 自行生成证书
生成证书(注意这种证书并不能被所有浏览器认可,部分手动授权可以访问,最好使用其他认证的官方ssl证书)
## certificate generation by yourself
generated certificate may not suitable for all browsers due to security concerns. you'd better buy or download an authenticated ssl certificate from authorized agency.
```shell
### 1)生成私钥,按照提示填写内容
### 1) Generate a private key
openssl genrsa -des3 -out server.key 1024
### 2)生成csr文件 ,按照提示填写内容
### 2) Generate a csr file
openssl req -new -key server.key -out server.csr
### 去掉pass
### 3) Remove pass
cp server.key server.key.org
openssl rsa -in server.key.org -out server.key
### 生成crt文件有效期1年365天
### 4) Generated a crt file, valid for 1 year
openssl x509 -req -days 365 -in server.csr -signkey server.key -out server.crt
```
## nginx配置说明(了解的可以跳过)
h5打开麦克风需要https协议同时后端的asr websocket也必须是wss协议如果[online asr](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/runtime/python/websocket)以ws方式运行我们可以通过nginx配置实现wss协议到ws协议的转换。
### nginx转发配置示例
## nginx configuration (you can skip it if you known)
https and wss protocol are required by browsers when want to open microphone and websocket.
if [online asr](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/runtime/python/websocket) run in ws way, you should use nginx to convert wss to ws.
### nginx wss->ws configuration example
```shell
events { [0/1548]
worker_connections 1024;

View File

@ -0,0 +1,111 @@
# online asr demo for html5
## requirement
### python
```shell
flask
gevent
pyOpenSSL
```
### javascript
[html5录音](https://github.com/xiangyuecn/Recorder)
```shell
Recorder
```
### demo页面如下
![img](https://github.com/alibaba-damo-academy/FunASR/blob/for-html5-demo/funasr/runtime/html5/demo.gif)
## 两种ws_server_online连接模式
### 1)直接连接模式浏览器https麦克风 --> html5 demo服务 --> js wss接口 --> wss asr online srv(证书生成请往后看)
### 2)nginx中转浏览器https麦克风 --> html5 demo服务 --> js wss接口 --> nginx服务 --> ws asr online srv
## 1.html5 demo服务启动
### 启动html5服务需要ssl证书(自己生成请往后看)
```shell
usage: h5Server.py [-h] [--host HOST] [--port PORT] [--certfile CERTFILE]
[--keyfile KEYFILE]
python h5Server.py --port 1337
```
## 2.启动ws or wss asr online srv
[具体请看online asr](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/runtime/python/websocket)
online asr提供两种ws和wss模式wss模式可以直接启动无需nginx中转。否则需要通过nginx将wss转发到该online asr的ws端口上
### wss方式
```shell
python ws_server_online.py --certfile server.crt --keyfile server.key --port 5921
```
### ws方式
```shell
python ws_server_online.py --port 5921
```
## 3.修改wsconnecter.js里asr接口地址
wsconnecter.js里配置online asr服务地址路径这里配置的是wss端口
var Uri = "wss://xxx:xxx/"
## 4.浏览器打开地址测试
https://127.0.0.1:1337/static/index.html
## 自行生成证书
生成证书(注意这种证书并不能被所有浏览器认可,部分手动授权可以访问,最好使用其他认证的官方ssl证书)
```shell
### 1)生成私钥,按照提示填写内容
openssl genrsa -des3 -out server.key 1024
### 2)生成csr文件 ,按照提示填写内容
openssl req -new -key server.key -out server.csr
### 去掉pass
cp server.key server.key.org
openssl rsa -in server.key.org -out server.key
### 生成crt文件有效期1年365天
openssl x509 -req -days 365 -in server.csr -signkey server.key -out server.crt
```
## nginx配置说明(了解的可以跳过)
h5打开麦克风需要https协议同时后端的asr websocket也必须是wss协议如果[online asr](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/runtime/python/websocket)以ws方式运行我们可以通过nginx配置实现wss协议到ws协议的转换。
### nginx转发配置示例
```shell
events { [0/1548]
worker_connections 1024;
accept_mutex on;
}
http {
error_log error.log;
access_log access.log;
server {
listen 5921 ssl http2; # nginx listen port for wss
server_name www.test.com;
ssl_certificate /funasr/server.crt;
ssl_certificate_key /funasr/server.key;
ssl_protocols TLSv1 TLSv1.1 TLSv1.2;
ssl_ciphers HIGH:!aNULL:!MD5;
location /wss/ {
proxy_pass http://127.0.0.1:1111/; # asr online model ws address and port
proxy_http_version 1.1;
proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection "upgrade";
proxy_read_timeout 600s;
}
}
```
### 修改wsconnecter.js里asr接口地址
wsconnecter.js里配置online asr服务地址路径这里配置的是wss端口
var Uri = "wss://xxx:xxx/wss/"
## Acknowledge
1. This project is maintained by [FunASR community](https://github.com/alibaba-damo-academy/FunASR).
2. We acknowledge [AiHealthx](http://www.aihealthx.com/) for contributing the html5 demo.

View File

@ -5,7 +5,7 @@
/* 2021-2023 by zhaoming,mali aihealthx.com */
function WebSocketConnectMethod( config ) { //定义socket连接方法类
var Uri = "wss://111.205.137.58:5821/wss/" //设置wss asr online接口地址 如 wss://X.X.X.X:port/wss/
var Uri = "wss://30.220.136.139:5921/" // var Uri = "wss://30.221.177.46:5921/" //设置wss asr online接口地址 如 wss://X.X.X.X:port/wss/
var speechSokt;
var connKeeperID;

View File

@ -11,15 +11,11 @@ class VadModel {
public:
virtual ~VadModel(){};
virtual void InitVad(const std::string &vad_model, const std::string &vad_cmvn, const std::string &vad_config, int thread_num)=0;
virtual std::vector<std::vector<int>> Infer(const std::vector<float> &waves)=0;
virtual std::vector<std::vector<int>> Infer(std::vector<float> &waves, bool input_finished=true)=0;
virtual void ReadModel(const char* vad_model)=0;
virtual void LoadConfigFromYaml(const char* filename)=0;
virtual void FbankKaldi(float sample_rate, std::vector<std::vector<float>> &vad_feats,
const std::vector<float> &waves)=0;
virtual void LfrCmvn(std::vector<std::vector<float>> &vad_feats)=0;
virtual void Forward(
const std::vector<std::vector<float>> &chunk_feats,
std::vector<std::vector<float>> *out_prob)=0;
std::vector<float> &waves)=0;
virtual void LoadCmvn(const char *filename)=0;
virtual void InitCache()=0;
};

View File

@ -127,6 +127,8 @@ For example:
### funasr-onnx-offline-rtf
```shell
./funasr-onnx-offline-rtf --model-dir <string> [--quantize <string>]
[--vad-dir <string>] [--vad-quant <string>]
[--punc-dir <string>] [--punc-quant <string>]
--wav-path <string> --thread-num <int32_t>
[--] [--version] [-h]
Where:
@ -136,6 +138,17 @@ Where:
(required) the model path, which contains model.onnx, config.yaml, am.mvn
--quantize <string>
false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir
--vad-dir <string>
the vad model path, which contains model.onnx, vad.yaml, vad.mvn
--vad-quant <string>
false (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir
--punc-dir <string>
the punc model path, which contains model.onnx, punc.yaml
--punc-quant <string>
false (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir
--wav-path <string>
(required) the input could be:
wav_path, e.g.: asr_example.wav;

View File

@ -162,17 +162,21 @@ void FsmnVad::Forward(
}
// get 4 caches outputs,each size is 128*19
for (int i = 1; i < 5; i++) {
float* data = vad_ort_outputs[i].GetTensorMutableData<float>();
memcpy(in_cache_[i-1].data(), data, sizeof(float) * 128*19);
}
// for (int i = 1; i < 5; i++) {
// float* data = vad_ort_outputs[i].GetTensorMutableData<float>();
// memcpy(in_cache_[i-1].data(), data, sizeof(float) * 128*19);
// }
}
void FsmnVad::FbankKaldi(float sample_rate, std::vector<std::vector<float>> &vad_feats,
const std::vector<float> &waves) {
std::vector<float> &waves) {
knf::OnlineFbank fbank(fbank_opts);
fbank.AcceptWaveform(sample_rate, &waves[0], waves.size());
std::vector<float> buf(waves.size());
for (int32_t i = 0; i != waves.size(); ++i) {
buf[i] = waves[i] * 32768;
}
fbank.AcceptWaveform(sample_rate, buf.data(), buf.size());
int32_t frames = fbank.NumFramesReady();
for (int32_t i = 0; i != frames; ++i) {
const float *frame = fbank.GetFrame(i);
@ -267,7 +271,7 @@ void FsmnVad::LfrCmvn(std::vector<std::vector<float>> &vad_feats) {
}
std::vector<std::vector<int>>
FsmnVad::Infer(const std::vector<float> &waves) {
FsmnVad::Infer(std::vector<float> &waves, bool input_finished) {
std::vector<std::vector<float>> vad_feats;
std::vector<std::vector<float>> vad_probs;
FbankKaldi(vad_sample_rate_, vad_feats, waves);

View File

@ -21,7 +21,7 @@ public:
~FsmnVad();
void Test();
void InitVad(const std::string &vad_model, const std::string &vad_cmvn, const std::string &vad_config, int thread_num);
std::vector<std::vector<int>> Infer(const std::vector<float> &waves);
std::vector<std::vector<int>> Infer(std::vector<float> &waves, bool input_finished=true);
void Reset();
private:
@ -34,7 +34,7 @@ private:
std::vector<const char *> *in_names, std::vector<const char *> *out_names);
void FbankKaldi(float sample_rate, std::vector<std::vector<float>> &vad_feats,
const std::vector<float> &waves);
std::vector<float> &waves);
void LfrCmvn(std::vector<std::vector<float>> &vad_feats);

View File

@ -39,7 +39,7 @@ void runReg(FUNASR_HANDLE asr_handle, vector<string> wav_list,
// warm up
for (size_t i = 0; i < 1; i++)
{
FUNASR_RESULT result=FunASRInfer(asr_handle, wav_list[0].c_str(), RASR_NONE, NULL, 16000);
FUNASR_RESULT result=FunOfflineInfer(asr_handle, wav_list[0].c_str(), RASR_NONE, NULL, 16000);
}
while (true) {
@ -50,7 +50,7 @@ void runReg(FUNASR_HANDLE asr_handle, vector<string> wav_list,
}
gettimeofday(&start, NULL);
FUNASR_RESULT result=FunASRInfer(asr_handle, wav_list[i].c_str(), RASR_NONE, NULL, 16000);
FUNASR_RESULT result=FunOfflineInfer(asr_handle, wav_list[i].c_str(), RASR_NONE, NULL, 16000);
gettimeofday(&end, NULL);
seconds = (end.tv_sec - start.tv_sec);
@ -102,12 +102,20 @@ int main(int argc, char *argv[])
TCLAP::CmdLine cmd("funasr-onnx-offline-rtf", ' ', "1.0");
TCLAP::ValueArg<std::string> model_dir("", MODEL_DIR, "the model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string");
TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string");
TCLAP::ValueArg<std::string> vad_dir("", VAD_DIR, "the vad model path, which contains model.onnx, vad.yaml, vad.mvn", false, "", "string");
TCLAP::ValueArg<std::string> vad_quant("", VAD_QUANT, "false (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "false", "string");
TCLAP::ValueArg<std::string> punc_dir("", PUNC_DIR, "the punc model path, which contains model.onnx, punc.yaml", false, "", "string");
TCLAP::ValueArg<std::string> punc_quant("", PUNC_QUANT, "false (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "false", "string");
TCLAP::ValueArg<std::string> wav_path("", WAV_PATH, "the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)", true, "", "string");
TCLAP::ValueArg<std::int32_t> thread_num("", THREAD_NUM, "multi-thread num for rtf", true, 0, "int32_t");
cmd.add(model_dir);
cmd.add(quantize);
cmd.add(vad_dir);
cmd.add(vad_quant);
cmd.add(punc_dir);
cmd.add(punc_quant);
cmd.add(wav_path);
cmd.add(thread_num);
cmd.parse(argc, argv);
@ -115,11 +123,15 @@ int main(int argc, char *argv[])
std::map<std::string, std::string> model_path;
GetValue(model_dir, MODEL_DIR, model_path);
GetValue(quantize, QUANTIZE, model_path);
GetValue(vad_dir, VAD_DIR, model_path);
GetValue(vad_quant, VAD_QUANT, model_path);
GetValue(punc_dir, PUNC_DIR, model_path);
GetValue(punc_quant, PUNC_QUANT, model_path);
GetValue(wav_path, WAV_PATH, model_path);
struct timeval start, end;
gettimeofday(&start, NULL);
FUNASR_HANDLE asr_handle=FunASRInit(model_path, 1);
FUNASR_HANDLE asr_handle=FunOfflineInit(model_path, 1);
if (!asr_handle)
{
@ -132,7 +144,7 @@ int main(int argc, char *argv[])
long modle_init_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
LOG(INFO) << "Model initialization takes " << (double)modle_init_micros / 1000000 << " s";
// read wav_scp
// read wav_path
vector<string> wav_list;
string wav_path_ = model_path.at(WAV_PATH);
if(is_target_file(wav_path_, "wav") || is_target_file(wav_path_, "pcm")){
@ -179,6 +191,6 @@ int main(int argc, char *argv[])
LOG(INFO) << "total_rtf " << (double)total_time/ (total_length*1000000);
LOG(INFO) << "speedup " << 1.0/((double)total_time/ (total_length*1000000));
FunASRUninit(asr_handle);
FunOfflineUninit(asr_handle);
return 0;
}

View File

@ -69,7 +69,11 @@ void Paraformer::Reset()
vector<float> Paraformer::FbankKaldi(float sample_rate, const float* waves, int len) {
knf::OnlineFbank fbank_(fbank_opts);
fbank_.AcceptWaveform(sample_rate, waves, len);
std::vector<float> buf(len);
for (int32_t i = 0; i != len; ++i) {
buf[i] = waves[i] * 32768;
}
fbank_.AcceptWaveform(sample_rate, buf.data(), buf.size());
//fbank_->InputFinished();
int32_t frames = fbank_.NumFramesReady();
int32_t feature_dim = fbank_opts.mel_opts.num_bins;

View File

@ -186,11 +186,12 @@ class CT_Transformer_VadRealtime(CT_Transformer):
mini_sentence = cache_sent + mini_sentence
mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0,dtype='int32')
text_length = len(mini_sentence_id)
vad_mask = self.vad_mask(text_length, len(cache))[None, None, :, :].astype(np.float32)
data = {
"input": mini_sentence_id[None,:],
"text_lengths": np.array([text_length], dtype='int32'),
"vad_mask": self.vad_mask(text_length, len(cache))[None, None, :, :].astype(np.float32),
"sub_masks": np.tril(np.ones((text_length, text_length), dtype=np.float32))[None, None, :, :].astype(np.float32)
"vad_mask": vad_mask,
"sub_masks": vad_mask
}
try:
outputs = self.infer(data['input'], data['text_lengths'], data['vad_mask'], data["sub_masks"])

View File

@ -32,15 +32,29 @@ inference_pipeline_asr_online = pipeline(
ncpu=args.ncpu,
model_revision='v1.0.4')
# vad
inference_pipeline_vad = pipeline(
task=Tasks.voice_activity_detection,
model=args.vad_model,
model_revision=None,
output_dir=None,
batch_size=1,
mode='online',
ngpu=args.ngpu,
ncpu=1,
)
print("model loaded")
async def ws_serve(websocket, path):
frames = []
frames_asr_online = []
global websocket_users
websocket_users.add(websocket)
websocket.param_dict_asr_online = {"cache": dict()}
websocket.param_dict_vad = {'in_cache': dict()}
websocket.wav_name = "microphone"
print("new user connected",flush=True)
try:
@ -53,9 +67,10 @@ async def ws_serve(websocket, path):
if "is_speaking" in messagejson:
websocket.is_speaking = messagejson["is_speaking"]
websocket.param_dict_asr_online["is_final"] = not websocket.is_speaking
websocket.param_dict_vad["is_final"] = not websocket.is_speaking
# need to fire engine manually if no data received any more
if not websocket.is_speaking:
await async_asr_online(websocket,b"")
await async_asr_online(websocket, b"")
if "chunk_interval" in messagejson:
websocket.chunk_interval=messagejson["chunk_interval"]
if "wav_name" in messagejson:
@ -64,14 +79,18 @@ async def ws_serve(websocket, path):
websocket.param_dict_asr_online["chunk_size"] = messagejson["chunk_size"]
# if has bytes in buffer or message is bytes
if len(frames_asr_online) > 0 or not isinstance(message, str):
if not isinstance(message,str):
if not isinstance(message, str):
frames_asr_online.append(message)
# frames.append(message)
# duration_ms = len(message) // 32
# websocket.vad_pre_idx += duration_ms
speech_start_i, speech_end_i = await async_vad(websocket, message)
websocket.is_speaking = not speech_end_i
if len(frames_asr_online) % websocket.chunk_interval == 0 or not websocket.is_speaking:
websocket.param_dict_asr_online["is_final"] = not websocket.is_speaking
audio_in = b"".join(frames_asr_online)
# if not websocket.is_speaking:
#padding 0.5s at end gurantee that asr engine can fire out last word
# audio_in=audio_in+b''.join(np.zeros(int(16000*0.5),dtype=np.int16))
await async_asr_online(websocket,audio_in)
await async_asr_online(websocket, audio_in)
frames_asr_online = []
@ -85,7 +104,7 @@ async def ws_serve(websocket, path):
async def async_asr_online(websocket,audio_in):
if len(audio_in) >=0:
if len(audio_in) >= 0:
audio_in = load_bytes(audio_in)
rec_result = inference_pipeline_asr_online(audio_in=audio_in,
param_dict=websocket.param_dict_asr_online)
@ -97,16 +116,30 @@ async def async_asr_online(websocket,audio_in):
await websocket.send(message)
async def async_vad(websocket, audio_in):
segments_result = inference_pipeline_vad(audio_in=audio_in, param_dict=websocket.param_dict_vad)
speech_start = False
speech_end = False
if len(segments_result) == 0 or len(segments_result["text"]) > 1:
return speech_start, speech_end
if segments_result["text"][0][0] != -1:
speech_start = segments_result["text"][0][0]
if segments_result["text"][0][1] != -1:
speech_end = True
return speech_start, speech_end
if len(args.certfile)>0:
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
# Generate with Lets Encrypt, copied to this location, chown to current user and 400 permissions
ssl_cert = args.certfile
ssl_key = args.keyfile
ssl_context.load_cert_chain(ssl_cert, keyfile=ssl_key)
start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None,ssl=ssl_context)
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
# Generate with Lets Encrypt, copied to this location, chown to current user and 400 permissions
ssl_cert = args.certfile
ssl_key = args.keyfile
ssl_context.load_cert_chain(ssl_cert, keyfile=ssl_key)
start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None,ssl=ssl_context)
else:
start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None)
start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None)
asyncio.get_event_loop().run_until_complete(start_server)
asyncio.get_event_loop().run_forever()

View File

@ -290,6 +290,8 @@ class ASRTask(AbsTask):
predictor_choices2,
# --stride_conv and --stride_conv_conf
stride_conv_choices,
# --rnnt_decoder and --rnnt_decoder_conf
rnnt_decoder_choices,
]
# If you need to modify train() or eval() procedures, change Trainer class here
@ -1360,7 +1362,7 @@ class ASRTaskAligner(ASRTaskParaformer):
return retval
class ASRTransducerTask(AbsTask):
class ASRTransducerTask(ASRTask):
"""ASR Transducer Task definition."""
num_optimizers: int = 1
@ -1371,244 +1373,11 @@ class ASRTransducerTask(AbsTask):
normalize_choices,
encoder_choices,
rnnt_decoder_choices,
joint_network_choices,
]
trainer = Trainer
@classmethod
def add_task_arguments(cls, parser: argparse.ArgumentParser):
"""Add Transducer task arguments.
Args:
cls: ASRTransducerTask object.
parser: Transducer arguments parser.
"""
group = parser.add_argument_group(description="Task related.")
# required = parser.get_default("required")
# required += ["token_list"]
group.add_argument(
"--token_list",
type=str_or_none,
default=None,
help="Integer-string mapper for tokens.",
)
group.add_argument(
"--split_with_space",
type=str2bool,
default=True,
help="whether to split text using <space>",
)
group.add_argument(
"--input_size",
type=int_or_none,
default=None,
help="The number of dimensions for input features.",
)
group.add_argument(
"--init",
type=str_or_none,
default=None,
help="Type of model initialization to use.",
)
group.add_argument(
"--model_conf",
action=NestedDictAction,
default=get_default_kwargs(TransducerModel),
help="The keyword arguments for the model class.",
)
# group.add_argument(
# "--encoder_conf",
# action=NestedDictAction,
# default={},
# help="The keyword arguments for the encoder class.",
# )
group.add_argument(
"--joint_network_conf",
action=NestedDictAction,
default={},
help="The keyword arguments for the joint network class.",
)
group = parser.add_argument_group(description="Preprocess related.")
group.add_argument(
"--use_preprocessor",
type=str2bool,
default=True,
help="Whether to apply preprocessing to input data.",
)
group.add_argument(
"--token_type",
type=str,
default="bpe",
choices=["bpe", "char", "word", "phn"],
help="The type of tokens to use during tokenization.",
)
group.add_argument(
"--bpemodel",
type=str_or_none,
default=None,
help="The path of the sentencepiece model.",
)
parser.add_argument(
"--non_linguistic_symbols",
type=str_or_none,
help="The 'non_linguistic_symbols' file path.",
)
parser.add_argument(
"--cleaner",
type=str_or_none,
choices=[None, "tacotron", "jaconv", "vietnamese"],
default=None,
help="Text cleaner to use.",
)
parser.add_argument(
"--g2p",
type=str_or_none,
choices=g2p_choices,
default=None,
help="g2p method to use if --token_type=phn.",
)
parser.add_argument(
"--speech_volume_normalize",
type=float_or_none,
default=None,
help="Normalization value for maximum amplitude scaling.",
)
parser.add_argument(
"--rir_scp",
type=str_or_none,
default=None,
help="The RIR SCP file path.",
)
parser.add_argument(
"--rir_apply_prob",
type=float,
default=1.0,
help="The probability of the applied RIR convolution.",
)
parser.add_argument(
"--noise_scp",
type=str_or_none,
default=None,
help="The path of noise SCP file.",
)
parser.add_argument(
"--noise_apply_prob",
type=float,
default=1.0,
help="The probability of the applied noise addition.",
)
parser.add_argument(
"--noise_db_range",
type=str,
default="13_15",
help="The range of the noise decibel level.",
)
for class_choices in cls.class_choices_list:
# Append --<name> and --<name>_conf.
# e.g. --decoder and --decoder_conf
class_choices.add_arguments(group)
@classmethod
def build_collate_fn(
cls, args: argparse.Namespace, train: bool
) -> Callable[
[Collection[Tuple[str, Dict[str, np.ndarray]]]],
Tuple[List[str], Dict[str, torch.Tensor]],
]:
"""Build collate function.
Args:
cls: ASRTransducerTask object.
args: Task arguments.
train: Training mode.
Return:
: Callable collate function.
"""
assert check_argument_types()
return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
@classmethod
def build_preprocess_fn(
cls, args: argparse.Namespace, train: bool
) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
"""Build pre-processing function.
Args:
cls: ASRTransducerTask object.
args: Task arguments.
train: Training mode.
Return:
: Callable pre-processing function.
"""
assert check_argument_types()
if args.use_preprocessor:
retval = CommonPreprocessor(
train=train,
token_type=args.token_type,
token_list=args.token_list,
bpemodel=args.bpemodel,
non_linguistic_symbols=args.non_linguistic_symbols,
text_cleaner=args.cleaner,
g2p_type=args.g2p,
split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False,
rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
rir_apply_prob=args.rir_apply_prob
if hasattr(args, "rir_apply_prob")
else 1.0,
noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
noise_apply_prob=args.noise_apply_prob
if hasattr(args, "noise_apply_prob")
else 1.0,
noise_db_range=args.noise_db_range
if hasattr(args, "noise_db_range")
else "13_15",
speech_volume_normalize=args.speech_volume_normalize
if hasattr(args, "rir_scp")
else None,
)
else:
retval = None
assert check_return_type(retval)
return retval
@classmethod
def required_data_names(
cls, train: bool = True, inference: bool = False
) -> Tuple[str, ...]:
"""Required data depending on task mode.
Args:
cls: ASRTransducerTask object.
train: Training mode.
inference: Inference mode.
Return:
retval: Required task data.
"""
if not inference:
retval = ("speech", "text")
else:
retval = ("speech",)
return retval
@classmethod
def optional_data_names(
cls, train: bool = True, inference: bool = False
) -> Tuple[str, ...]:
"""Optional data depending on task mode.
Args:
cls: ASRTransducerTask object.
train: Training mode.
inference: Inference mode.
Return:
retval: Optional task data.
"""
retval = ()
assert check_return_type(retval)
return retval
@classmethod
def build_model(cls, args: argparse.Namespace) -> TransducerModel:
"""Required data depending on task mode.