mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
update online runtime, including vad-online, paraformer-online, punc-online,2pass (#815)
* init * update * add LoadConfigFromYaml * update * update * update * del time stat * update * update * update * update * update * update * update * add cpp websocket online 2pass srv * [feature] multithread grpc server * update * update * update * [feature] support 2pass grpc cpp server and python client, can change mode to use offline, online or 2pass decoding * update * update * update * update * add paraformer online onnx model export * add paraformer online onnx model export * add paraformer online onnx model export * add paraformer online onnxruntime * add paraformer online onnxruntime * add paraformer online onnxruntime * fix export paraformer online onnx model bug * for client closed earlier and core dump * support GRPC two pass decoding (#813) * [refator] optimize grpc server pipeline and instruction * [refator] rm useless file * [refator] optimize grpc client pipeline and instruction * [debug] hanlde coredump when client ternimated * [refator] rm useless log * [refator] modify grpc cmake * Create run_server_2pass.sh * Update SDK_tutorial_online_zh.md * Update SDK_tutorial_online.md * Update SDK_advanced_guide_online.md * Update SDK_advanced_guide_online_zh.md * Update SDK_tutorial_online_zh.md * Update SDK_tutorial_online.md * update --------- Co-authored-by: zhaoming <zhaomingwork@qq.com> Co-authored-by: boji123 <boji123@aliyun.com> Co-authored-by: haoneng.lhn <haoneng.lhn@alibaba-inc.com>
This commit is contained in:
parent
57968c2180
commit
b454a1054f
3
.gitignore
vendored
3
.gitignore
vendored
@ -19,4 +19,5 @@ build
|
||||
funasr.egg-info
|
||||
docs/_build
|
||||
modelscope
|
||||
samples
|
||||
samples
|
||||
.ipynb_checkpoints
|
||||
|
||||
@ -55,18 +55,21 @@ class ModelExport:
|
||||
|
||||
# export encoder1
|
||||
self.export_config["model_name"] = "model"
|
||||
model = get_model(
|
||||
models = get_model(
|
||||
model,
|
||||
self.export_config,
|
||||
)
|
||||
model.eval()
|
||||
# self._export_onnx(model, verbose, export_dir)
|
||||
if self.onnx:
|
||||
self._export_onnx(model, verbose, export_dir)
|
||||
else:
|
||||
self._export_torchscripts(model, verbose, export_dir)
|
||||
|
||||
print("output dir: {}".format(export_dir))
|
||||
if not isinstance(models, tuple):
|
||||
models = (models,)
|
||||
|
||||
for i, model in enumerate(models):
|
||||
model.eval()
|
||||
if self.onnx:
|
||||
self._export_onnx(model, verbose, export_dir)
|
||||
else:
|
||||
self._export_torchscripts(model, verbose, export_dir)
|
||||
|
||||
print("output dir: {}".format(export_dir))
|
||||
|
||||
|
||||
def _torch_quantize(self, model):
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from funasr.models.e2e_asr_paraformer import Paraformer, BiCifParaformer
|
||||
from funasr.models.e2e_asr_paraformer import Paraformer, BiCifParaformer, ParaformerOnline
|
||||
from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
|
||||
from funasr.export.models.e2e_asr_paraformer import BiCifParaformer as BiCifParaformer_export
|
||||
from funasr.export.models.e2e_asr_conformer import Conformer as Conformer_export
|
||||
@ -10,10 +10,15 @@ from funasr.export.models.CT_Transformer import CT_Transformer as CT_Transformer
|
||||
from funasr.train.abs_model import PunctuationModel
|
||||
from funasr.models.vad_realtime_transformer import VadRealtimeTransformer
|
||||
from funasr.export.models.CT_Transformer import CT_Transformer_VadRealtime as CT_Transformer_VadRealtime_export
|
||||
from funasr.export.models.e2e_asr_paraformer import ParaformerOnline_encoder_predictor as ParaformerOnline_encoder_predictor_export
|
||||
from funasr.export.models.e2e_asr_paraformer import ParaformerOnline_decoder as ParaformerOnline_decoder_export
|
||||
|
||||
def get_model(model, export_config=None):
|
||||
if isinstance(model, BiCifParaformer):
|
||||
return BiCifParaformer_export(model, **export_config)
|
||||
elif isinstance(model, ParaformerOnline):
|
||||
return (ParaformerOnline_encoder_predictor_export(model, model_name="model"),
|
||||
ParaformerOnline_decoder_export(model, model_name="decoder"))
|
||||
elif isinstance(model, Paraformer):
|
||||
return Paraformer_export(model, **export_config)
|
||||
elif isinstance(model, Conformer_export):
|
||||
|
||||
@ -157,3 +157,158 @@ class ParaformerSANMDecoder(nn.Module):
|
||||
"n_layers": len(self.model.decoders) + len(self.model.decoders2),
|
||||
"odim": self.model.decoders[0].size
|
||||
}
|
||||
|
||||
|
||||
class ParaformerSANMDecoderOnline(nn.Module):
|
||||
def __init__(self, model,
|
||||
max_seq_len=512,
|
||||
model_name='decoder',
|
||||
onnx: bool = True, ):
|
||||
super().__init__()
|
||||
# self.embed = model.embed #Embedding(model.embed, max_seq_len)
|
||||
self.model = model
|
||||
if onnx:
|
||||
self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
|
||||
else:
|
||||
self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
|
||||
|
||||
for i, d in enumerate(self.model.decoders):
|
||||
if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
|
||||
d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
|
||||
if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
|
||||
d.self_attn = MultiHeadedAttentionSANMDecoder_export(d.self_attn)
|
||||
if isinstance(d.src_attn, MultiHeadedAttentionCrossAtt):
|
||||
d.src_attn = MultiHeadedAttentionCrossAtt_export(d.src_attn)
|
||||
self.model.decoders[i] = DecoderLayerSANM_export(d)
|
||||
|
||||
if self.model.decoders2 is not None:
|
||||
for i, d in enumerate(self.model.decoders2):
|
||||
if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
|
||||
d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
|
||||
if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
|
||||
d.self_attn = MultiHeadedAttentionSANMDecoder_export(d.self_attn)
|
||||
self.model.decoders2[i] = DecoderLayerSANM_export(d)
|
||||
|
||||
for i, d in enumerate(self.model.decoders3):
|
||||
if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
|
||||
d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
|
||||
self.model.decoders3[i] = DecoderLayerSANM_export(d)
|
||||
|
||||
self.output_layer = model.output_layer
|
||||
self.after_norm = model.after_norm
|
||||
self.model_name = model_name
|
||||
|
||||
def prepare_mask(self, mask):
|
||||
mask_3d_btd = mask[:, :, None]
|
||||
if len(mask.shape) == 2:
|
||||
mask_4d_bhlt = 1 - mask[:, None, None, :]
|
||||
elif len(mask.shape) == 3:
|
||||
mask_4d_bhlt = 1 - mask[:, None, :]
|
||||
mask_4d_bhlt = mask_4d_bhlt * -10000.0
|
||||
|
||||
return mask_3d_btd, mask_4d_bhlt
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hs_pad: torch.Tensor,
|
||||
hlens: torch.Tensor,
|
||||
ys_in_pad: torch.Tensor,
|
||||
ys_in_lens: torch.Tensor,
|
||||
*args,
|
||||
):
|
||||
|
||||
tgt = ys_in_pad
|
||||
tgt_mask = self.make_pad_mask(ys_in_lens)
|
||||
tgt_mask, _ = self.prepare_mask(tgt_mask)
|
||||
# tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
|
||||
|
||||
memory = hs_pad
|
||||
memory_mask = self.make_pad_mask(hlens)
|
||||
_, memory_mask = self.prepare_mask(memory_mask)
|
||||
# memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
|
||||
|
||||
x = tgt
|
||||
out_caches = list()
|
||||
for i, decoder in enumerate(self.model.decoders):
|
||||
in_cache = args[i]
|
||||
x, tgt_mask, memory, memory_mask, out_cache = decoder(
|
||||
x, tgt_mask, memory, memory_mask, cache=in_cache
|
||||
)
|
||||
out_caches.append(out_cache)
|
||||
if self.model.decoders2 is not None:
|
||||
for i, decoder in enumerate(self.model.decoders2):
|
||||
in_cache = args[i+len(self.model.decoders)]
|
||||
x, tgt_mask, memory, memory_mask, out_cache = decoder(
|
||||
x, tgt_mask, memory, memory_mask, cache=in_cache
|
||||
)
|
||||
out_caches.append(out_cache)
|
||||
x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(
|
||||
x, tgt_mask, memory, memory_mask
|
||||
)
|
||||
x = self.after_norm(x)
|
||||
x = self.output_layer(x)
|
||||
|
||||
return x, out_caches
|
||||
|
||||
def get_dummy_inputs(self, enc_size):
|
||||
enc = torch.randn(2, 100, enc_size).type(torch.float32)
|
||||
enc_len = torch.tensor([30, 100], dtype=torch.int32)
|
||||
acoustic_embeds = torch.randn(2, 10, enc_size).type(torch.float32)
|
||||
acoustic_embeds_len = torch.tensor([5, 10], dtype=torch.int32)
|
||||
cache_num = len(self.model.decoders)
|
||||
if hasattr(self.model, 'decoders2') and self.model.decoders2 is not None:
|
||||
cache_num += len(self.model.decoders2)
|
||||
cache = [
|
||||
torch.zeros((2, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size-1), dtype=torch.float32)
|
||||
for _ in range(cache_num)
|
||||
]
|
||||
return (enc, enc_len, acoustic_embeds, acoustic_embeds_len, *cache)
|
||||
|
||||
def get_input_names(self):
|
||||
cache_num = len(self.model.decoders)
|
||||
if hasattr(self.model, 'decoders2') and self.model.decoders2 is not None:
|
||||
cache_num += len(self.model.decoders2)
|
||||
return ['enc', 'enc_len', 'acoustic_embeds', 'acoustic_embeds_len'] \
|
||||
+ ['in_cache_%d' % i for i in range(cache_num)]
|
||||
|
||||
def get_output_names(self):
|
||||
cache_num = len(self.model.decoders)
|
||||
if hasattr(self.model, 'decoders2') and self.model.decoders2 is not None:
|
||||
cache_num += len(self.model.decoders2)
|
||||
return ['logits', 'sample_ids'] \
|
||||
+ ['out_cache_%d' % i for i in range(cache_num)]
|
||||
|
||||
def get_dynamic_axes(self):
|
||||
ret = {
|
||||
'enc': {
|
||||
0: 'batch_size',
|
||||
1: 'enc_length'
|
||||
},
|
||||
'acoustic_embeds': {
|
||||
0: 'batch_size',
|
||||
1: 'token_length'
|
||||
},
|
||||
'enc_len': {
|
||||
0: 'batch_size',
|
||||
},
|
||||
'acoustic_embeds_len': {
|
||||
0: 'batch_size',
|
||||
},
|
||||
|
||||
}
|
||||
cache_num = len(self.model.decoders)
|
||||
if hasattr(self.model, 'decoders2') and self.model.decoders2 is not None:
|
||||
cache_num += len(self.model.decoders2)
|
||||
ret.update({
|
||||
'in_cache_%d' % d: {
|
||||
0: 'batch_size',
|
||||
}
|
||||
for d in range(cache_num)
|
||||
})
|
||||
ret.update({
|
||||
'out_cache_%d' % d: {
|
||||
0: 'batch_size',
|
||||
}
|
||||
for d in range(cache_num)
|
||||
})
|
||||
return ret
|
||||
|
||||
@ -4,7 +4,7 @@ import torch.nn as nn
|
||||
|
||||
from funasr.export.utils.torch_function import MakePadMask
|
||||
from funasr.export.utils.torch_function import sequence_mask
|
||||
from funasr.models.encoder.sanm_encoder import SANMEncoder
|
||||
from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
|
||||
from funasr.models.encoder.conformer_encoder import ConformerEncoder
|
||||
from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export
|
||||
from funasr.export.models.encoder.conformer_encoder import ConformerEncoder as ConformerEncoder_export
|
||||
@ -15,6 +15,7 @@ from funasr.models.decoder.sanm_decoder import ParaformerSANMDecoder
|
||||
from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
|
||||
from funasr.export.models.decoder.sanm_decoder import ParaformerSANMDecoder as ParaformerSANMDecoder_export
|
||||
from funasr.export.models.decoder.transformer_decoder import ParaformerDecoderSAN as ParaformerDecoderSAN_export
|
||||
from funasr.export.models.decoder.sanm_decoder import ParaformerSANMDecoderOnline as ParaformerSANMDecoderOnline_export
|
||||
|
||||
|
||||
class Paraformer(nn.Module):
|
||||
@ -216,4 +217,150 @@ class BiCifParaformer(nn.Module):
|
||||
0: 'batch_size',
|
||||
1: 'alphas_length'
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class ParaformerOnline_encoder_predictor(nn.Module):
|
||||
"""
|
||||
Author: Speech Lab, Alibaba Group, China
|
||||
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
|
||||
https://arxiv.org/abs/2206.08317
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
max_seq_len=512,
|
||||
feats_dim=560,
|
||||
model_name='model',
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
onnx = False
|
||||
if "onnx" in kwargs:
|
||||
onnx = kwargs["onnx"]
|
||||
if isinstance(model.encoder, SANMEncoder) or isinstance(model.encoder, SANMEncoderChunkOpt):
|
||||
self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
|
||||
elif isinstance(model.encoder, ConformerEncoder):
|
||||
self.encoder = ConformerEncoder_export(model.encoder, onnx=onnx)
|
||||
if isinstance(model.predictor, CifPredictorV2):
|
||||
self.predictor = CifPredictorV2_export(model.predictor)
|
||||
|
||||
self.feats_dim = feats_dim
|
||||
self.model_name = model_name
|
||||
|
||||
if onnx:
|
||||
self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
|
||||
else:
|
||||
self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
speech: torch.Tensor,
|
||||
speech_lengths: torch.Tensor,
|
||||
):
|
||||
# a. To device
|
||||
batch = {"speech": speech, "speech_lengths": speech_lengths, "online": True}
|
||||
# batch = to_device(batch, device=self.device)
|
||||
|
||||
enc, enc_len = self.encoder(**batch)
|
||||
mask = self.make_pad_mask(enc_len)[:, None, :]
|
||||
alphas, _ = self.predictor.forward_cnn(enc, mask)
|
||||
|
||||
return enc, enc_len, alphas
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
speech = torch.randn(2, 30, self.feats_dim)
|
||||
speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
|
||||
return (speech, speech_lengths)
|
||||
|
||||
def get_input_names(self):
|
||||
return ['speech', 'speech_lengths']
|
||||
|
||||
def get_output_names(self):
|
||||
return ['enc', 'enc_len', 'alphas']
|
||||
|
||||
def get_dynamic_axes(self):
|
||||
return {
|
||||
'speech': {
|
||||
0: 'batch_size',
|
||||
1: 'feats_length'
|
||||
},
|
||||
'speech_lengths': {
|
||||
0: 'batch_size',
|
||||
},
|
||||
'enc': {
|
||||
0: 'batch_size',
|
||||
1: 'feats_length'
|
||||
},
|
||||
'enc_len': {
|
||||
0: 'batch_size',
|
||||
},
|
||||
'alphas': {
|
||||
0: 'batch_size',
|
||||
1: 'feats_length'
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class ParaformerOnline_decoder(nn.Module):
|
||||
"""
|
||||
Author: Speech Lab, Alibaba Group, China
|
||||
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
|
||||
https://arxiv.org/abs/2206.08317
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
max_seq_len=512,
|
||||
feats_dim=560,
|
||||
model_name='model',
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
onnx = False
|
||||
if "onnx" in kwargs:
|
||||
onnx = kwargs["onnx"]
|
||||
|
||||
if isinstance(model.decoder, ParaformerDecoderSAN):
|
||||
self.decoder = ParaformerDecoderSAN_export(model.decoder, onnx=onnx)
|
||||
elif isinstance(model.decoder, ParaformerSANMDecoder):
|
||||
self.decoder = ParaformerSANMDecoderOnline_export(model.decoder, onnx=onnx)
|
||||
|
||||
self.feats_dim = feats_dim
|
||||
self.model_name = model_name
|
||||
self.enc_size = model.encoder._output_size
|
||||
|
||||
if onnx:
|
||||
self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
|
||||
else:
|
||||
self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
enc: torch.Tensor,
|
||||
enc_len: torch.Tensor,
|
||||
acoustic_embeds: torch.Tensor,
|
||||
acoustic_embeds_len: torch.Tensor,
|
||||
*args,
|
||||
):
|
||||
decoder_out, out_caches = self.decoder(enc, enc_len, acoustic_embeds, acoustic_embeds_len, *args)
|
||||
sample_ids = decoder_out.argmax(dim=-1)
|
||||
|
||||
return decoder_out, sample_ids, out_caches
|
||||
|
||||
def get_dummy_inputs(self, ):
|
||||
dummy_inputs = self.decoder.get_dummy_inputs(enc_size=self.enc_size)
|
||||
return dummy_inputs
|
||||
|
||||
def get_input_names(self):
|
||||
|
||||
return self.decoder.get_input_names()
|
||||
|
||||
def get_output_names(self):
|
||||
|
||||
return self.decoder.get_output_names()
|
||||
|
||||
def get_dynamic_axes(self):
|
||||
return self.decoder.get_dynamic_axes()
|
||||
|
||||
@ -8,6 +8,7 @@ from funasr.export.models.modules.multihead_att import MultiHeadedAttentionSANM
|
||||
from funasr.export.models.modules.encoder_layer import EncoderLayerSANM as EncoderLayerSANM_export
|
||||
from funasr.modules.positionwise_feed_forward import PositionwiseFeedForward
|
||||
from funasr.export.models.modules.feedforward import PositionwiseFeedForward as PositionwiseFeedForward_export
|
||||
from funasr.modules.embedding import StreamSinusoidalPositionEncoder
|
||||
|
||||
|
||||
class SANMEncoder(nn.Module):
|
||||
@ -21,6 +22,8 @@ class SANMEncoder(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
self.embed = model.embed
|
||||
if isinstance(self.embed, StreamSinusoidalPositionEncoder):
|
||||
self.embed = None
|
||||
self.model = model
|
||||
self.feats_dim = feats_dim
|
||||
self._output_size = model._output_size
|
||||
@ -63,8 +66,10 @@ class SANMEncoder(nn.Module):
|
||||
def forward(self,
|
||||
speech: torch.Tensor,
|
||||
speech_lengths: torch.Tensor,
|
||||
online: bool = False
|
||||
):
|
||||
speech = speech * self._output_size ** 0.5
|
||||
if not online:
|
||||
speech = speech * self._output_size ** 0.5
|
||||
mask = self.make_pad_mask(speech_lengths)
|
||||
mask = self.prepare_mask(mask)
|
||||
if self.embed is None:
|
||||
|
||||
@ -64,14 +64,14 @@ class MultiHeadedAttentionSANM(nn.Module):
|
||||
return self.linear_out(context_layer) # (batch, time1, d_model)
|
||||
|
||||
|
||||
def preprocess_for_attn(x, mask, cache, pad_fn):
|
||||
def preprocess_for_attn(x, mask, cache, pad_fn, kernel_size):
|
||||
x = x * mask
|
||||
x = x.transpose(1, 2)
|
||||
if cache is None:
|
||||
x = pad_fn(x)
|
||||
else:
|
||||
x = torch.cat((cache[:, :, 1:], x), dim=2)
|
||||
cache = x
|
||||
x = torch.cat((cache, x), dim=2)
|
||||
cache = x[:, :, -(kernel_size-1):]
|
||||
return x, cache
|
||||
|
||||
|
||||
@ -90,7 +90,7 @@ class MultiHeadedAttentionSANMDecoder(nn.Module):
|
||||
self.attn = None
|
||||
|
||||
def forward(self, inputs, mask, cache=None):
|
||||
x, cache = preprocess_for_attn(inputs, mask, cache, self.pad_fn)
|
||||
x, cache = preprocess_for_attn(inputs, mask, cache, self.pad_fn, self.kernel_size)
|
||||
x = self.fsmn_block(x)
|
||||
x = x.transpose(1, 2)
|
||||
|
||||
|
||||
@ -36,6 +36,17 @@ class CifPredictorV2(nn.Module):
|
||||
def forward(self, hidden: torch.Tensor,
|
||||
mask: torch.Tensor,
|
||||
):
|
||||
alphas, token_num = self.forward_cnn(hidden, mask)
|
||||
mask = mask.transpose(-1, -2).float()
|
||||
mask = mask.squeeze(-1)
|
||||
hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, mask=mask)
|
||||
acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
|
||||
|
||||
return acoustic_embeds, token_num, alphas, cif_peak
|
||||
|
||||
def forward_cnn(self, hidden: torch.Tensor,
|
||||
mask: torch.Tensor,
|
||||
):
|
||||
h = hidden
|
||||
context = h.transpose(1, 2)
|
||||
queries = self.pad(context)
|
||||
@ -49,12 +60,8 @@ class CifPredictorV2(nn.Module):
|
||||
alphas = alphas * mask
|
||||
alphas = alphas.squeeze(-1)
|
||||
token_num = alphas.sum(-1)
|
||||
|
||||
mask = mask.squeeze(-1)
|
||||
hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, mask=mask)
|
||||
acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
|
||||
|
||||
return acoustic_embeds, token_num, alphas, cif_peak
|
||||
|
||||
return alphas, token_num
|
||||
|
||||
def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
|
||||
b, t, d = hidden.size()
|
||||
@ -285,4 +292,4 @@ def cif_wo_hidden(alphas, threshold: float):
|
||||
integrate)
|
||||
|
||||
fires = torch.stack(list_fires, 1)
|
||||
return fires
|
||||
return fires
|
||||
|
||||
@ -185,7 +185,9 @@ Introduction to command parameters:
|
||||
--port: the port number of the server listener.
|
||||
--wav-path: the audio input. Input can be a path to a wav file or a wav.scp file (a Kaldi-formatted wav list in which each line includes a wav_id followed by a tab and a wav_path).
|
||||
--is-ssl: whether to use SSL encryption. The default is to use SSL.
|
||||
--mode: offline mode.
|
||||
--mode: 2pass.
|
||||
--thread-num 1
|
||||
|
||||
```
|
||||
|
||||
### Custom client
|
||||
@ -194,7 +196,9 @@ If you want to define your own client, the Websocket communication protocol is a
|
||||
|
||||
```text
|
||||
# First communication
|
||||
{"mode": "offline", "wav_name": "wav_name", "is_speaking": True, "wav_format":"pcm", "chunk_size":[5,10,5]}# Send wav data
|
||||
{"mode": "offline", "wav_name": "wav_name", "is_speaking": True, "wav_format":"pcm", "chunk_size":[5,10,5]}
|
||||
# Send wav data
|
||||
|
||||
Bytes data
|
||||
# Send end flag
|
||||
{"is_speaking": False}
|
||||
|
||||
@ -76,7 +76,7 @@ Command parameter instructions:
|
||||
|
||||
After entering the samples/cpp directory, you can test it with CPP. The command is as follows:
|
||||
```shell
|
||||
./funasr-wss-client --server-ip 127.0.0.1 --port 10095 --wav-path ../audio/asr_example.wav
|
||||
./funasr-wss-client-2pass --server-ip 127.0.0.1 --port 10095 --wav-path ../audio/asr_example.wav
|
||||
```
|
||||
|
||||
Command parameter description:
|
||||
|
||||
@ -1,51 +1,44 @@
|
||||
# Copyright 2018 gRPC authors.
|
||||
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
|
||||
# Reserved. MIT License (https://opensource.org/licenses/MIT)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# cmake build file for C++ paraformer example.
|
||||
# Assumes protobuf and gRPC have been installed using cmake.
|
||||
# See cmake_externalproject/CMakeLists.txt for all-in-one cmake build
|
||||
# that automatically builds all the dependencies before building paraformer.
|
||||
# 2023 by burkliu(刘柏基) liubaiji@xverse.cn
|
||||
|
||||
cmake_minimum_required(VERSION 3.10)
|
||||
|
||||
project(ASR C CXX)
|
||||
|
||||
set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.")
|
||||
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
|
||||
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
|
||||
set(CMAKE_VERBOSE_MAKEFILE on)
|
||||
set(BUILD_TESTING OFF)
|
||||
|
||||
include(common.cmake)
|
||||
|
||||
# Proto file
|
||||
get_filename_component(rg_proto "../python/grpc/proto/paraformer.proto" ABSOLUTE)
|
||||
get_filename_component(rg_proto_path "${rg_proto}" PATH)
|
||||
get_filename_component(rg_proto ../python/grpc/proto/paraformer.proto ABSOLUTE)
|
||||
get_filename_component(rg_proto_path ${rg_proto} PATH)
|
||||
|
||||
# Generated sources
|
||||
set(rg_proto_srcs "${CMAKE_CURRENT_BINARY_DIR}/paraformer.pb.cc")
|
||||
set(rg_proto_hdrs "${CMAKE_CURRENT_BINARY_DIR}/paraformer.pb.h")
|
||||
set(rg_grpc_srcs "${CMAKE_CURRENT_BINARY_DIR}/paraformer.grpc.pb.cc")
|
||||
set(rg_grpc_hdrs "${CMAKE_CURRENT_BINARY_DIR}/paraformer.grpc.pb.h")
|
||||
set(rg_proto_srcs ${CMAKE_CURRENT_BINARY_DIR}/paraformer.pb.cc)
|
||||
set(rg_proto_hdrs ${CMAKE_CURRENT_BINARY_DIR}/paraformer.pb.h)
|
||||
set(rg_grpc_srcs ${CMAKE_CURRENT_BINARY_DIR}/paraformer.grpc.pb.cc)
|
||||
set(rg_grpc_hdrs ${CMAKE_CURRENT_BINARY_DIR}/paraformer.grpc.pb.h)
|
||||
add_custom_command(
|
||||
OUTPUT "${rg_proto_srcs}" "${rg_proto_hdrs}" "${rg_grpc_srcs}" "${rg_grpc_hdrs}"
|
||||
COMMAND ${_PROTOBUF_PROTOC}
|
||||
ARGS --grpc_out "${CMAKE_CURRENT_BINARY_DIR}"
|
||||
--cpp_out "${CMAKE_CURRENT_BINARY_DIR}"
|
||||
-I "${rg_proto_path}"
|
||||
--plugin=protoc-gen-grpc="${_GRPC_CPP_PLUGIN_EXECUTABLE}"
|
||||
"${rg_proto}"
|
||||
DEPENDS "${rg_proto}")
|
||||
OUTPUT ${rg_proto_srcs} ${rg_proto_hdrs} ${rg_grpc_srcs} ${rg_grpc_hdrs}
|
||||
COMMAND ${_PROTOBUF_PROTOC}
|
||||
ARGS --grpc_out ${CMAKE_CURRENT_BINARY_DIR}
|
||||
--cpp_out ${CMAKE_CURRENT_BINARY_DIR}
|
||||
-I ${rg_proto_path}
|
||||
--plugin=protoc-gen-grpc=${_GRPC_CPP_PLUGIN_EXECUTABLE}
|
||||
${rg_proto}
|
||||
DEPENDS ${rg_proto})
|
||||
|
||||
# Include generated *.pb.h files
|
||||
include_directories("${CMAKE_CURRENT_BINARY_DIR}")
|
||||
include_directories(${CMAKE_CURRENT_BINARY_DIR})
|
||||
|
||||
link_directories(${ONNXRUNTIME_DIR}/lib)
|
||||
link_directories(${FFMPEG_DIR}/lib)
|
||||
|
||||
include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/include/)
|
||||
include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/yaml-cpp/include/)
|
||||
@ -53,33 +46,21 @@ include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/kaldi-nativ
|
||||
|
||||
add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/yaml-cpp yaml-cpp)
|
||||
add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/kaldi-native-fbank/kaldi-native-fbank/csrc csrc)
|
||||
add_subdirectory("../onnxruntime/src" onnx_src)
|
||||
add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/src src)
|
||||
|
||||
include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog)
|
||||
set(BUILD_TESTING OFF)
|
||||
add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog glog)
|
||||
|
||||
# rg_grpc_proto
|
||||
add_library(rg_grpc_proto
|
||||
${rg_grpc_srcs}
|
||||
${rg_grpc_hdrs}
|
||||
${rg_proto_srcs}
|
||||
${rg_proto_hdrs})
|
||||
add_library(rg_grpc_proto ${rg_grpc_srcs} ${rg_grpc_hdrs} ${rg_proto_srcs} ${rg_proto_hdrs})
|
||||
|
||||
target_link_libraries(rg_grpc_proto
|
||||
target_link_libraries(rg_grpc_proto ${_REFLECTION} ${_GRPC_GRPCPP} ${_PROTOBUF_LIBPROTOBUF})
|
||||
|
||||
add_executable(paraformer-server paraformer-server.cc)
|
||||
target_link_libraries(paraformer-server
|
||||
rg_grpc_proto
|
||||
funasr
|
||||
${EXTRA_LIBS}
|
||||
${_REFLECTION}
|
||||
${_GRPC_GRPCPP}
|
||||
${_PROTOBUF_LIBPROTOBUF})
|
||||
|
||||
foreach(_target
|
||||
paraformer-server)
|
||||
add_executable(${_target}
|
||||
"${_target}.cc")
|
||||
target_link_libraries(${_target}
|
||||
rg_grpc_proto
|
||||
funasr
|
||||
${EXTRA_LIBS}
|
||||
${_REFLECTION}
|
||||
${_GRPC_GRPCPP}
|
||||
${_PROTOBUF_LIBPROTOBUF})
|
||||
endforeach()
|
||||
|
||||
@ -2,17 +2,20 @@
|
||||
|
||||
## For the Server
|
||||
|
||||
### Build [onnxruntime](./onnxruntime_cpp.md) as it's document
|
||||
### 1. Build [onnxruntime](../websocket/readme.md) as it's document
|
||||
|
||||
### Compile and install grpc v1.52.0 in case of grpc bugs
|
||||
```
|
||||
export GRPC_INSTALL_DIR=/data/soft/grpc
|
||||
export PKG_CONFIG_PATH=$GRPC_INSTALL_DIR/lib/pkgconfig
|
||||
### 2. Compile and install grpc v1.52.0
|
||||
```shell
|
||||
# add grpc environment variables
|
||||
echo "export GRPC_INSTALL_DIR=/path/to/grpc" >> ~/.bashrc
|
||||
echo "export PKG_CONFIG_PATH=\$GRPC_INSTALL_DIR/lib/pkgconfig" >> ~/.bashrc
|
||||
echo "export PATH=\$GRPC_INSTALL_DIR/bin/:\$PKG_CONFIG_PATH:\$PATH" >> ~/.bashrc
|
||||
source ~/.bashrc
|
||||
|
||||
# install grpc
|
||||
git clone --recurse-submodules -b v1.52.0 --depth 1 --shallow-submodules https://github.com/grpc/grpc
|
||||
|
||||
git clone -b v1.52.0 --depth=1 https://github.com/grpc/grpc.git
|
||||
cd grpc
|
||||
git submodule update --init --recursive
|
||||
|
||||
mkdir -p cmake/build
|
||||
pushd cmake/build
|
||||
cmake -DgRPC_INSTALL=ON \
|
||||
@ -22,182 +25,57 @@ cmake -DgRPC_INSTALL=ON \
|
||||
make
|
||||
make install
|
||||
popd
|
||||
|
||||
echo "export GRPC_INSTALL_DIR=/data/soft/grpc" >> ~/.bashrc
|
||||
echo "export PKG_CONFIG_PATH=\$GRPC_INSTALL_DIR/lib/pkgconfig" >> ~/.bashrc
|
||||
echo "export PATH=\$GRPC_INSTALL_DIR/bin/:\$PKG_CONFIG_PATH:\$PATH" >> ~/.bashrc
|
||||
source ~/.bashrc
|
||||
```
|
||||
|
||||
### Compile and start grpc onnx paraformer server
|
||||
```
|
||||
# set -DONNXRUNTIME_DIR=/path/to/asrmodel/onnxruntime-linux-x64-1.14.0
|
||||
./rebuild.sh
|
||||
### 3. Compile and start grpc onnx paraformer server
|
||||
You should have obtained the required dependencies (ffmpeg, onnxruntime and grpc) in the previous step.
|
||||
|
||||
If no, run [download_ffmpeg](../onnxruntime/third_party/download_ffmpeg.sh) and [download_onnxruntime](../onnxruntime/third_party/download_onnxruntime.sh)
|
||||
|
||||
```shell
|
||||
cd /cfs/user/burkliu/work2023/FunASR/funasr/runtime/grpc
|
||||
./build.sh
|
||||
```
|
||||
|
||||
### Start grpc paraformer server
|
||||
```
|
||||
### 4. Download paraformer model
|
||||
To do.
|
||||
|
||||
### 5. Start grpc paraformer server
|
||||
```shell
|
||||
# run as default
|
||||
./run_server.sh
|
||||
|
||||
# or run server directly
|
||||
./build/bin/paraformer-server \
|
||||
--port-id <string> \
|
||||
--offline-model-dir <string> \
|
||||
--online-model-dir <string> \
|
||||
--quantize <string> \
|
||||
--vad-dir <string> \
|
||||
--vad-quant <string> \
|
||||
--punc-dir <string> \
|
||||
--punc-quant <string>
|
||||
|
||||
./cmake/build/paraformer-server --port-id <string> [--punc-quant <string>]
|
||||
[--punc-dir <string>] [--vad-quant <string>]
|
||||
[--vad-dir <string>] [--quantize <string>]
|
||||
--model-dir <string> [--] [--version] [-h]
|
||||
Where:
|
||||
--port-id <string>
|
||||
(required) port id
|
||||
--model-dir <string>
|
||||
(required) the asr model path, which contains model.onnx, config.yaml, am.mvn
|
||||
--quantize <string>
|
||||
false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir
|
||||
--port-id <string> (required) the port server listen to
|
||||
|
||||
--vad-dir <string>
|
||||
the vad model path, which contains model.onnx, vad.yaml, vad.mvn
|
||||
--vad-quant <string>
|
||||
false (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir
|
||||
--offline-model-dir <string> (required) the offline asr model path
|
||||
--online-model-dir <string> (required) the online asr model path
|
||||
--quantize <string> (optional) false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir
|
||||
|
||||
--punc-dir <string>
|
||||
the punc model path, which contains model.onnx, punc.yaml
|
||||
--punc-quant <string>
|
||||
false (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir
|
||||
|
||||
Required: --port-id <string> --model-dir <string>
|
||||
If use vad, please add: --vad-dir <string>
|
||||
If use punc, please add: --punc-dir <string>
|
||||
--vad-dir <string> (required) the vad model path
|
||||
--vad-quant <string> (optional) false (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir
|
||||
|
||||
--punc-dir <string> (required) the punc model path
|
||||
--punc-quant <string> (optional) false (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir
|
||||
```
|
||||
|
||||
## For the client
|
||||
Currently we only support python grpc server.
|
||||
|
||||
### Install the requirements as in [grpc-python](./docs/grpc_python.md)
|
||||
|
||||
```shell
|
||||
git clone https://github.com/alibaba/FunASR.git && cd FunASR
|
||||
cd funasr/runtime/python/grpc
|
||||
pip install -r requirements_client.txt
|
||||
```
|
||||
|
||||
### Generate protobuf file
|
||||
Run on server, the two generated pb files are both used for server and client
|
||||
|
||||
```shell
|
||||
# paraformer_pb2.py and paraformer_pb2_grpc.py are already generated,
|
||||
# regenerate it only when you make changes to ./proto/paraformer.proto file.
|
||||
python -m grpc_tools.protoc --proto_path=./proto -I ./proto --python_out=. --grpc_python_out=./ ./proto/paraformer.proto
|
||||
```
|
||||
|
||||
### Start grpc client
|
||||
```
|
||||
# Start client.
|
||||
python grpc_main_client_mic.py --host 127.0.0.1 --port 10095
|
||||
```
|
||||
|
||||
[//]: # (```)
|
||||
|
||||
[//]: # (# go to ../python/grpc to find this package)
|
||||
|
||||
[//]: # (import paraformer_pb2)
|
||||
|
||||
[//]: # ()
|
||||
[//]: # ()
|
||||
[//]: # (class RecognizeStub:)
|
||||
|
||||
[//]: # ( def __init__(self, channel):)
|
||||
|
||||
[//]: # ( self.Recognize = channel.stream_stream()
|
||||
|
||||
[//]: # ( '/paraformer.ASR/Recognize',)
|
||||
|
||||
[//]: # ( request_serializer=paraformer_pb2.Request.SerializeToString,)
|
||||
|
||||
[//]: # ( response_deserializer=paraformer_pb2.Response.FromString,)
|
||||
|
||||
[//]: # ( ))
|
||||
|
||||
[//]: # ()
|
||||
[//]: # ()
|
||||
[//]: # (async def send(channel, data, speaking, isEnd):)
|
||||
|
||||
[//]: # ( stub = RecognizeStub(channel))
|
||||
|
||||
[//]: # ( req = paraformer_pb2.Request())
|
||||
|
||||
[//]: # ( if data:)
|
||||
|
||||
[//]: # ( req.audio_data = data)
|
||||
|
||||
[//]: # ( req.user = 'zz')
|
||||
|
||||
[//]: # ( req.language = 'zh-CN')
|
||||
|
||||
[//]: # ( req.speaking = speaking)
|
||||
|
||||
[//]: # ( req.isEnd = isEnd)
|
||||
|
||||
[//]: # ( q = queue.SimpleQueue())
|
||||
|
||||
[//]: # ( q.put(req))
|
||||
|
||||
[//]: # ( return stub.Recognize(iter(q.get, None)))
|
||||
|
||||
[//]: # ()
|
||||
[//]: # (# send the audio data once)
|
||||
|
||||
[//]: # (async def grpc_rec(data, grpc_uri):)
|
||||
|
||||
[//]: # ( with grpc.insecure_channel(grpc_uri) as channel:)
|
||||
|
||||
[//]: # ( b = time.time())
|
||||
|
||||
[//]: # ( response = await send(channel, data, False, False))
|
||||
|
||||
[//]: # ( resp = response.next())
|
||||
|
||||
[//]: # ( text = '')
|
||||
|
||||
[//]: # ( if 'decoding' == resp.action:)
|
||||
|
||||
[//]: # ( resp = response.next())
|
||||
|
||||
[//]: # ( if 'finish' == resp.action:)
|
||||
|
||||
[//]: # ( text = json.loads(resp.sentence)['text'])
|
||||
|
||||
[//]: # ( response = await send(channel, None, False, True))
|
||||
|
||||
[//]: # ( return {)
|
||||
|
||||
[//]: # ( 'text': text,)
|
||||
|
||||
[//]: # ( 'time': time.time() - b,)
|
||||
|
||||
[//]: # ( })
|
||||
|
||||
[//]: # ()
|
||||
[//]: # (async def test():)
|
||||
|
||||
[//]: # ( # fc = FunAsrGrpcClient('127.0.0.1', 9900))
|
||||
|
||||
[//]: # ( # t = await fc.rec(wav.tobytes()))
|
||||
|
||||
[//]: # ( # print(t))
|
||||
|
||||
[//]: # ( wav, _ = sf.read('z-10s.wav', dtype='int16'))
|
||||
|
||||
[//]: # ( uri = '127.0.0.1:9900')
|
||||
|
||||
[//]: # ( res = await grpc_rec(wav.tobytes(), uri))
|
||||
|
||||
[//]: # ( print(res))
|
||||
|
||||
[//]: # ()
|
||||
[//]: # ()
|
||||
[//]: # (if __name__ == '__main__':)
|
||||
|
||||
[//]: # ( asyncio.run(test()))
|
||||
|
||||
[//]: # ()
|
||||
[//]: # (```)
|
||||
Install the requirements as in [grpc-python](../python/grpc/Readme.md)
|
||||
|
||||
|
||||
## Acknowledge
|
||||
1. This project is maintained by [FunASR community](https://github.com/alibaba-damo-academy/FunASR).
|
||||
2. We acknowledge [DeepScience](https://www.deepscience.cn) for contributing the grpc service.
|
||||
2. We acknowledge burkliu (刘柏基, liubaiji@xverse.cn) for contributing the grpc service.
|
||||
|
||||
15
funasr/runtime/grpc/build.sh
Executable file
15
funasr/runtime/grpc/build.sh
Executable file
@ -0,0 +1,15 @@
|
||||
#!/bin/bash
|
||||
|
||||
mode=debug #[debug|release]
|
||||
onnxruntime_dir=`pwd`/../onnxruntime/onnxruntime-linux-x64-1.14.0
|
||||
ffmpeg_dir=`pwd`/../onnxruntime/ffmpeg-N-111383-g20b8688092-linux64-gpl-shared
|
||||
|
||||
|
||||
rm build -rf
|
||||
mkdir -p build
|
||||
cd build
|
||||
|
||||
cmake -DCMAKE_BUILD_TYPE=$mode ../ -DONNXRUNTIME_DIR=$onnxruntime_dir -DFFMPEG_DIR=$ffmpeg_dir
|
||||
cmake --build . -j 4
|
||||
|
||||
echo "Build server successfully!"
|
||||
@ -1,235 +1,261 @@
|
||||
#include <algorithm>
|
||||
#include <chrono>
|
||||
#include <cmath>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
/**
|
||||
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
|
||||
* Reserved. MIT License (https://opensource.org/licenses/MIT)
|
||||
*/
|
||||
/* 2023 by burkliu(刘柏基) liubaiji@xverse.cn */
|
||||
|
||||
#include <grpc/grpc.h>
|
||||
#include <grpcpp/server.h>
|
||||
#include <grpcpp/server_builder.h>
|
||||
#include <grpcpp/server_context.h>
|
||||
#include <grpcpp/security/server_credentials.h>
|
||||
|
||||
#include "paraformer.grpc.pb.h"
|
||||
#include "paraformer-server.h"
|
||||
#include "tclap/CmdLine.h"
|
||||
#include "com-define.h"
|
||||
#include "glog/logging.h"
|
||||
|
||||
using grpc::Server;
|
||||
using grpc::ServerBuilder;
|
||||
using grpc::ServerContext;
|
||||
using grpc::ServerReader;
|
||||
using grpc::ServerReaderWriter;
|
||||
using grpc::ServerWriter;
|
||||
using grpc::Status;
|
||||
GrpcEngine::GrpcEngine(
|
||||
grpc::ServerReaderWriter<Response, Request>* stream,
|
||||
std::shared_ptr<FUNASR_HANDLE> asr_handler)
|
||||
: stream_(std::move(stream)),
|
||||
asr_handler_(std::move(asr_handler)) {
|
||||
|
||||
using paraformer::Request;
|
||||
using paraformer::Response;
|
||||
using paraformer::ASR;
|
||||
|
||||
ASRServicer::ASRServicer(std::map<std::string, std::string>& model_path) {
|
||||
AsrHanlde=FunOfflineInit(model_path, 1);
|
||||
std::cout << "ASRServicer init" << std::endl;
|
||||
init_flag = 0;
|
||||
request_ = std::make_shared<Request>();
|
||||
}
|
||||
|
||||
void ASRServicer::clear_states(const std::string& user) {
|
||||
clear_buffers(user);
|
||||
clear_transcriptions(user);
|
||||
}
|
||||
void GrpcEngine::DecodeThreadFunc() {
|
||||
FUNASR_HANDLE tpass_online_handler = FunTpassOnlineInit(*asr_handler_, chunk_size_);
|
||||
int step = (sampling_rate_ * step_duration_ms_ / 1000) * 2; // int16 = 2bytes;
|
||||
std::vector<std::vector<std::string>> punc_cache(2);
|
||||
|
||||
void ASRServicer::clear_buffers(const std::string& user) {
|
||||
if (client_buffers.count(user)) {
|
||||
client_buffers.erase(user);
|
||||
}
|
||||
}
|
||||
bool is_final = false;
|
||||
std::string online_result = "";
|
||||
std::string tpass_result = "";
|
||||
|
||||
void ASRServicer::clear_transcriptions(const std::string& user) {
|
||||
if (client_transcription.count(user)) {
|
||||
client_transcription.erase(user);
|
||||
}
|
||||
}
|
||||
LOG(INFO) << "Decoder init, start decoding loop with mode";
|
||||
|
||||
void ASRServicer::disconnect(const std::string& user) {
|
||||
clear_states(user);
|
||||
std::cout << "Disconnecting user: " << user << std::endl;
|
||||
}
|
||||
while (true) {
|
||||
if (audio_buffer_.length() > step || is_end_) {
|
||||
if (audio_buffer_.length() <= step && is_end_) {
|
||||
is_final = true;
|
||||
step = audio_buffer_.length();
|
||||
}
|
||||
|
||||
grpc::Status ASRServicer::Recognize(
|
||||
grpc::ServerContext* context,
|
||||
grpc::ServerReaderWriter<Response, Request>* stream) {
|
||||
FUNASR_RESULT result = FunTpassInferBuffer(*asr_handler_,
|
||||
tpass_online_handler,
|
||||
audio_buffer_.c_str(),
|
||||
step,
|
||||
punc_cache,
|
||||
is_final,
|
||||
sampling_rate_,
|
||||
encoding_,
|
||||
mode_);
|
||||
audio_buffer_ = audio_buffer_.substr(step);
|
||||
|
||||
Request req;
|
||||
while (stream->Read(&req)) {
|
||||
if (req.isend()) {
|
||||
std::cout << "asr end" << std::endl;
|
||||
disconnect(req.user());
|
||||
Response res;
|
||||
res.set_sentence(
|
||||
R"({"success": true, "detail": "asr end"})"
|
||||
);
|
||||
res.set_user(req.user());
|
||||
res.set_action("terminate");
|
||||
res.set_language(req.language());
|
||||
stream->Write(res);
|
||||
} else if (req.speaking()) {
|
||||
if (req.audio_data().size() > 0) {
|
||||
auto& buf = client_buffers[req.user()];
|
||||
buf.insert(buf.end(), req.audio_data().begin(), req.audio_data().end());
|
||||
}
|
||||
Response res;
|
||||
res.set_sentence(
|
||||
R"({"success": true, "detail": "speaking"})"
|
||||
);
|
||||
res.set_user(req.user());
|
||||
res.set_action("speaking");
|
||||
res.set_language(req.language());
|
||||
stream->Write(res);
|
||||
} else if (!req.speaking()) {
|
||||
if (client_buffers.count(req.user()) == 0 && req.audio_data().size() == 0) {
|
||||
Response res;
|
||||
res.set_sentence(
|
||||
R"({"success": true, "detail": "waiting_for_voice"})"
|
||||
);
|
||||
res.set_user(req.user());
|
||||
res.set_action("waiting");
|
||||
res.set_language(req.language());
|
||||
stream->Write(res);
|
||||
}else {
|
||||
auto begin_time = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch()).count();
|
||||
if (req.audio_data().size() > 0) {
|
||||
auto& buf = client_buffers[req.user()];
|
||||
buf.insert(buf.end(), req.audio_data().begin(), req.audio_data().end());
|
||||
}
|
||||
std::string tmp_data = this->client_buffers[req.user()];
|
||||
this->clear_states(req.user());
|
||||
|
||||
Response res;
|
||||
res.set_sentence(
|
||||
R"({"success": true, "detail": "decoding data: " + std::to_string(tmp_data.length()) + " bytes"})"
|
||||
);
|
||||
int data_len_int = tmp_data.length();
|
||||
std::string data_len = std::to_string(data_len_int);
|
||||
std::stringstream ss;
|
||||
ss << R"({"success": true, "detail": "decoding data: )" << data_len << R"( bytes")" << R"("})";
|
||||
std::string result = ss.str();
|
||||
res.set_sentence(result);
|
||||
res.set_user(req.user());
|
||||
res.set_action("decoding");
|
||||
res.set_language(req.language());
|
||||
stream->Write(res);
|
||||
if (tmp_data.length() < 800) { //min input_len for asr model
|
||||
auto end_time = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch()).count();
|
||||
std::string delay_str = std::to_string(end_time - begin_time);
|
||||
std::cout << "user: " << req.user() << " , delay(ms): " << delay_str << ", error: data_is_not_long_enough" << std::endl;
|
||||
Response res;
|
||||
std::stringstream ss;
|
||||
std::string asr_result = "";
|
||||
ss << R"({"success": true, "detail": "finish_sentence","server_delay_ms":)" << delay_str << R"(,"text":")" << asr_result << R"("})";
|
||||
std::string result = ss.str();
|
||||
res.set_sentence(result);
|
||||
res.set_user(req.user());
|
||||
res.set_action("finish");
|
||||
res.set_language(req.language());
|
||||
stream->Write(res);
|
||||
}
|
||||
else {
|
||||
FUNASR_RESULT Result= FunOfflineInferBuffer(AsrHanlde, tmp_data.c_str(), data_len_int, RASR_NONE, NULL, 16000);
|
||||
std::string asr_result = ((FUNASR_RECOG_RESULT*)Result)->msg;
|
||||
|
||||
auto end_time = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch()).count();
|
||||
std::string delay_str = std::to_string(end_time - begin_time);
|
||||
|
||||
std::cout << "user: " << req.user() << " , delay(ms): " << delay_str << ", text: " << asr_result << std::endl;
|
||||
Response res;
|
||||
std::stringstream ss;
|
||||
ss << R"({"success": true, "detail": "finish_sentence","server_delay_ms":)" << delay_str << R"(,"text":")" << asr_result << R"("})";
|
||||
std::string result = ss.str();
|
||||
res.set_sentence(result);
|
||||
res.set_user(req.user());
|
||||
res.set_action("finish");
|
||||
res.set_language(req.language());
|
||||
|
||||
stream->Write(res);
|
||||
}
|
||||
}
|
||||
}else {
|
||||
Response res;
|
||||
res.set_sentence(
|
||||
R"({"success": false, "detail": "error, no condition matched! Unknown reason."})"
|
||||
);
|
||||
res.set_user(req.user());
|
||||
res.set_action("terminate");
|
||||
res.set_language(req.language());
|
||||
stream->Write(res);
|
||||
if (result) {
|
||||
std::string online_message = FunASRGetResult(result, 0);
|
||||
online_result += online_message;
|
||||
if(online_message != ""){
|
||||
Response response;
|
||||
response.set_mode(DecodeMode::online);
|
||||
response.set_text(online_message);
|
||||
response.set_is_final(is_final);
|
||||
stream_->Write(response);
|
||||
LOG(INFO) << "send online results: " << online_message;
|
||||
}
|
||||
std::string tpass_message = FunASRGetTpassResult(result, 0);
|
||||
tpass_result += tpass_message;
|
||||
if(tpass_message != ""){
|
||||
Response response;
|
||||
response.set_mode(DecodeMode::two_pass);
|
||||
response.set_text(tpass_message);
|
||||
response.set_is_final(is_final);
|
||||
stream_->Write(response);
|
||||
LOG(INFO) << "send offline results: " << tpass_message;
|
||||
}
|
||||
FunASRFreeResult(result);
|
||||
}
|
||||
|
||||
if (is_final) {
|
||||
FunTpassOnlineUninit(tpass_online_handler);
|
||||
break;
|
||||
}
|
||||
}
|
||||
return Status::OK;
|
||||
sleep(0.001);
|
||||
}
|
||||
}
|
||||
|
||||
void RunServer(std::map<std::string, std::string>& model_path) {
|
||||
std::string port;
|
||||
try{
|
||||
port = model_path.at(PORT_ID);
|
||||
}catch(std::exception const &e){
|
||||
printf("Error when read port.\n");
|
||||
exit(0);
|
||||
void GrpcEngine::OnSpeechStart() {
|
||||
if (request_->chunk_size_size() == 3) {
|
||||
for (int i = 0; i < 3; i++) {
|
||||
chunk_size_[i] = int(request_->chunk_size(i));
|
||||
}
|
||||
std::string server_address;
|
||||
server_address = "0.0.0.0:" + port;
|
||||
ASRServicer service(model_path);
|
||||
}
|
||||
std::string chunk_size_str;
|
||||
for (int i = 0; i < 3; i++) {
|
||||
chunk_size_str = " " + chunk_size_[i];
|
||||
}
|
||||
LOG(INFO) << "chunk_size is" << chunk_size_str;
|
||||
|
||||
ServerBuilder builder;
|
||||
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
|
||||
builder.RegisterService(&service);
|
||||
std::unique_ptr<Server> server(builder.BuildAndStart());
|
||||
std::cout << "Server listening on " << server_address << std::endl;
|
||||
server->Wait();
|
||||
if (request_->sampling_rate() != 0) {
|
||||
sampling_rate_ = request_->sampling_rate();
|
||||
}
|
||||
LOG(INFO) << "sampling_rate is " << sampling_rate_;
|
||||
|
||||
switch(request_->wav_format()) {
|
||||
case WavFormat::pcm: encoding_ = "pcm";
|
||||
}
|
||||
LOG(INFO) << "encoding is " << encoding_;
|
||||
|
||||
std::string mode_str;
|
||||
switch(request_->mode()) {
|
||||
case DecodeMode::offline:
|
||||
mode_ = ASR_OFFLINE;
|
||||
mode_str = "offline";
|
||||
break;
|
||||
case DecodeMode::online:
|
||||
mode_ = ASR_ONLINE;
|
||||
mode_str = "online";
|
||||
break;
|
||||
case DecodeMode::two_pass:
|
||||
mode_ = ASR_TWO_PASS;
|
||||
mode_str = "two_pass";
|
||||
break;
|
||||
}
|
||||
LOG(INFO) << "decode mode is " << mode_str;
|
||||
|
||||
decode_thread_ = std::make_shared<std::thread>(&GrpcEngine::DecodeThreadFunc, this);
|
||||
is_start_ = true;
|
||||
}
|
||||
|
||||
void GetValue(TCLAP::ValueArg<std::string>& value_arg, std::string key, std::map<std::string, std::string>& model_path)
|
||||
{
|
||||
if (value_arg.isSet()){
|
||||
model_path.insert({key, value_arg.getValue()});
|
||||
LOG(INFO)<< key << " : " << value_arg.getValue();
|
||||
void GrpcEngine::OnSpeechData() {
|
||||
audio_buffer_ += request_->audio_data();
|
||||
}
|
||||
|
||||
void GrpcEngine::OnSpeechEnd() {
|
||||
is_end_ = true;
|
||||
LOG(INFO) << "Read all pcm data, wait for decoding thread";
|
||||
if (decode_thread_ != nullptr) {
|
||||
decode_thread_->join();
|
||||
}
|
||||
}
|
||||
|
||||
void GrpcEngine::operator()() {
|
||||
try {
|
||||
LOG(INFO) << "start engine main loop";
|
||||
while (stream_->Read(request_.get())) {
|
||||
LOG(INFO) << "receive data";
|
||||
if (!is_start_) {
|
||||
OnSpeechStart();
|
||||
}
|
||||
OnSpeechData();
|
||||
if (request_->is_final()) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
OnSpeechEnd();
|
||||
LOG(INFO) << "Connect finish";
|
||||
} catch (std::exception const& e) {
|
||||
LOG(ERROR) << e.what();
|
||||
}
|
||||
}
|
||||
|
||||
GrpcService::GrpcService(std::map<std::string, std::string>& config, int onnx_thread)
|
||||
: config_(config) {
|
||||
|
||||
asr_handler_ = std::make_shared<FUNASR_HANDLE>(std::move(FunTpassInit(config_, onnx_thread)));
|
||||
LOG(INFO) << "GrpcService model loaded";
|
||||
|
||||
std::vector<int> chunk_size = {5, 10, 5};
|
||||
FUNASR_HANDLE tmp_online_handler = FunTpassOnlineInit(*asr_handler_, chunk_size);
|
||||
int sampling_rate = 16000;
|
||||
int buffer_len = sampling_rate * 1;
|
||||
std::string tmp_data(buffer_len, '0');
|
||||
std::vector<std::vector<std::string>> punc_cache(2);
|
||||
bool is_final = true;
|
||||
std::string encoding = "pcm";
|
||||
FUNASR_RESULT result = FunTpassInferBuffer(*asr_handler_,
|
||||
tmp_online_handler,
|
||||
tmp_data.c_str(),
|
||||
buffer_len,
|
||||
punc_cache,
|
||||
is_final,
|
||||
buffer_len,
|
||||
encoding,
|
||||
ASR_TWO_PASS);
|
||||
if (result) {
|
||||
FunASRFreeResult(result);
|
||||
}
|
||||
FunTpassOnlineUninit(tmp_online_handler);
|
||||
LOG(INFO) << "GrpcService model warmup";
|
||||
}
|
||||
|
||||
grpc::Status GrpcService::Recognize(
|
||||
grpc::ServerContext* context,
|
||||
grpc::ServerReaderWriter<Response, Request>* stream) {
|
||||
LOG(INFO) << "Get Recognize request";
|
||||
GrpcEngine engine(
|
||||
stream,
|
||||
asr_handler_
|
||||
);
|
||||
|
||||
std::thread t(std::move(engine));
|
||||
t.join();
|
||||
return grpc::Status::OK;
|
||||
}
|
||||
|
||||
void GetValue(TCLAP::ValueArg<std::string>& value_arg, std::string key, std::map<std::string, std::string>& config) {
|
||||
if (value_arg.isSet()) {
|
||||
config.insert({key, value_arg.getValue()});
|
||||
LOG(INFO) << key << " : " << value_arg.getValue();
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
FLAGS_logtostderr = true;
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
FLAGS_logtostderr = true;
|
||||
TCLAP::CmdLine cmd("funasr-onnx-2pass", ' ', "1.0");
|
||||
TCLAP::ValueArg<std::string> offline_model_dir("", OFFLINE_MODEL_DIR, "the asr offline model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string");
|
||||
TCLAP::ValueArg<std::string> online_model_dir("", ONLINE_MODEL_DIR, "the asr online model path, which contains encoder.onnx, decoder.onnx, config.yaml, am.mvn", true, "", "string");
|
||||
TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string");
|
||||
TCLAP::ValueArg<std::string> vad_dir("", VAD_DIR, "the vad online model path, which contains model.onnx, vad.yaml, vad.mvn", false, "", "string");
|
||||
TCLAP::ValueArg<std::string> vad_quant("", VAD_QUANT, "false (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "true", "string");
|
||||
TCLAP::ValueArg<std::string> punc_dir("", PUNC_DIR, "the punc online model path, which contains model.onnx, punc.yaml", false, "", "string");
|
||||
TCLAP::ValueArg<std::string> punc_quant("", PUNC_QUANT, "false (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string");
|
||||
TCLAP::ValueArg<std::int32_t> onnx_thread("", "onnx-inter-thread", "onnxruntime SetIntraOpNumThreads", false, 1, "int32_t");
|
||||
TCLAP::ValueArg<std::string> port_id("", PORT_ID, "port id", true, "", "string");
|
||||
|
||||
TCLAP::CmdLine cmd("paraformer-server", ' ', "1.0");
|
||||
TCLAP::ValueArg<std::string> model_dir("", MODEL_DIR, "the asr model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string");
|
||||
TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string");
|
||||
TCLAP::ValueArg<std::string> vad_dir("", VAD_DIR, "the vad model path, which contains model.onnx, vad.yaml, vad.mvn", false, "", "string");
|
||||
TCLAP::ValueArg<std::string> vad_quant("", VAD_QUANT, "false (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "false", "string");
|
||||
TCLAP::ValueArg<std::string> punc_dir("", PUNC_DIR, "the punc model path, which contains model.onnx, punc.yaml", false, "", "string");
|
||||
TCLAP::ValueArg<std::string> punc_quant("", PUNC_QUANT, "false (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "false", "string");
|
||||
TCLAP::ValueArg<std::string> port_id("", PORT_ID, "port id", true, "", "string");
|
||||
cmd.add(offline_model_dir);
|
||||
cmd.add(online_model_dir);
|
||||
cmd.add(quantize);
|
||||
cmd.add(vad_dir);
|
||||
cmd.add(vad_quant);
|
||||
cmd.add(punc_dir);
|
||||
cmd.add(punc_quant);
|
||||
cmd.add(onnx_thread);
|
||||
cmd.add(port_id);
|
||||
cmd.parse(argc, argv);
|
||||
|
||||
cmd.add(model_dir);
|
||||
cmd.add(quantize);
|
||||
cmd.add(vad_dir);
|
||||
cmd.add(vad_quant);
|
||||
cmd.add(punc_dir);
|
||||
cmd.add(punc_quant);
|
||||
cmd.add(port_id);
|
||||
cmd.parse(argc, argv);
|
||||
std::map<std::string, std::string> config;
|
||||
GetValue(offline_model_dir, OFFLINE_MODEL_DIR, config);
|
||||
GetValue(online_model_dir, ONLINE_MODEL_DIR, config);
|
||||
GetValue(quantize, QUANTIZE, config);
|
||||
GetValue(vad_dir, VAD_DIR, config);
|
||||
GetValue(vad_quant, VAD_QUANT, config);
|
||||
GetValue(punc_dir, PUNC_DIR, config);
|
||||
GetValue(punc_quant, PUNC_QUANT, config);
|
||||
GetValue(port_id, PORT_ID, config);
|
||||
|
||||
std::map<std::string, std::string> model_path;
|
||||
GetValue(model_dir, MODEL_DIR, model_path);
|
||||
GetValue(quantize, QUANTIZE, model_path);
|
||||
GetValue(vad_dir, VAD_DIR, model_path);
|
||||
GetValue(vad_quant, VAD_QUANT, model_path);
|
||||
GetValue(punc_dir, PUNC_DIR, model_path);
|
||||
GetValue(punc_quant, PUNC_QUANT, model_path);
|
||||
GetValue(port_id, PORT_ID, model_path);
|
||||
std::string port;
|
||||
try {
|
||||
port = config.at(PORT_ID);
|
||||
} catch(std::exception const &e) {
|
||||
LOG(INFO) << ("Error when read port.");
|
||||
exit(0);
|
||||
}
|
||||
std::string server_address;
|
||||
server_address = "0.0.0.0:" + port;
|
||||
GrpcService service(config, onnx_thread);
|
||||
|
||||
RunServer(model_path);
|
||||
return 0;
|
||||
grpc::ServerBuilder builder;
|
||||
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
|
||||
builder.RegisterService(&service);
|
||||
std::unique_ptr<grpc::Server> server(builder.BuildAndStart());
|
||||
LOG(INFO) << "Server listening on " << server_address;
|
||||
server->Wait();
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -1,55 +1,65 @@
|
||||
#include <algorithm>
|
||||
#include <chrono>
|
||||
#include <cmath>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
/**
|
||||
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
|
||||
* Reserved. MIT License (https://opensource.org/licenses/MIT)
|
||||
*/
|
||||
/* 2023 by burkliu(刘柏基) liubaiji@xverse.cn */
|
||||
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <unistd.h>
|
||||
|
||||
#include <grpc/grpc.h>
|
||||
#include <grpcpp/server.h>
|
||||
#include <grpcpp/server_builder.h>
|
||||
#include <grpcpp/server_context.h>
|
||||
#include <grpcpp/security/server_credentials.h>
|
||||
|
||||
#include <unordered_map>
|
||||
#include <chrono>
|
||||
|
||||
#include "grpcpp/server_builder.h"
|
||||
#include "paraformer.grpc.pb.h"
|
||||
#include "funasrruntime.h"
|
||||
#include "tclap/CmdLine.h"
|
||||
#include "com-define.h"
|
||||
#include "glog/logging.h"
|
||||
|
||||
|
||||
using grpc::Server;
|
||||
using grpc::ServerBuilder;
|
||||
using grpc::ServerContext;
|
||||
using grpc::ServerReader;
|
||||
using grpc::ServerReaderWriter;
|
||||
using grpc::ServerWriter;
|
||||
using grpc::Status;
|
||||
|
||||
|
||||
using paraformer::WavFormat;
|
||||
using paraformer::DecodeMode;
|
||||
using paraformer::Request;
|
||||
using paraformer::Response;
|
||||
using paraformer::ASR;
|
||||
|
||||
typedef struct
|
||||
{
|
||||
std::string msg;
|
||||
float snippet_time;
|
||||
}FUNASR_RECOG_RESULT;
|
||||
std::string msg;
|
||||
float snippet_time;
|
||||
} FUNASR_RECOG_RESULT;
|
||||
|
||||
class ASRServicer final : public ASR::Service {
|
||||
private:
|
||||
int init_flag;
|
||||
std::unordered_map<std::string, std::string> client_buffers;
|
||||
std::unordered_map<std::string, std::string> client_transcription;
|
||||
class GrpcEngine {
|
||||
public:
|
||||
GrpcEngine(grpc::ServerReaderWriter<Response, Request>* stream, std::shared_ptr<FUNASR_HANDLE> asr_handler);
|
||||
void operator()();
|
||||
|
||||
public:
|
||||
ASRServicer(std::map<std::string, std::string>& model_path);
|
||||
void clear_states(const std::string& user);
|
||||
void clear_buffers(const std::string& user);
|
||||
void clear_transcriptions(const std::string& user);
|
||||
void disconnect(const std::string& user);
|
||||
grpc::Status Recognize(grpc::ServerContext* context, grpc::ServerReaderWriter<Response, Request>* stream);
|
||||
FUNASR_HANDLE AsrHanlde;
|
||||
|
||||
private:
|
||||
void DecodeThreadFunc();
|
||||
void OnSpeechStart();
|
||||
void OnSpeechData();
|
||||
void OnSpeechEnd();
|
||||
|
||||
grpc::ServerReaderWriter<Response, Request>* stream_;
|
||||
std::shared_ptr<Request> request_;
|
||||
std::shared_ptr<Response> response_;
|
||||
std::shared_ptr<FUNASR_HANDLE> asr_handler_;
|
||||
std::string audio_buffer_;
|
||||
std::shared_ptr<std::thread> decode_thread_ = nullptr;
|
||||
bool is_start_ = false;
|
||||
bool is_end_ = false;
|
||||
|
||||
std::vector<int> chunk_size_ = {5, 10, 5};
|
||||
int sampling_rate_ = 16000;
|
||||
std::string encoding_;
|
||||
ASR_TYPE mode_ = ASR_TWO_PASS;
|
||||
int step_duration_ms_ = 100;
|
||||
};
|
||||
|
||||
class GrpcService final : public ASR::Service {
|
||||
public:
|
||||
GrpcService(std::map<std::string, std::string>& config, int num_thread);
|
||||
grpc::Status Recognize(grpc::ServerContext* context, grpc::ServerReaderWriter<Response, Request>* stream);
|
||||
|
||||
private:
|
||||
std::map<std::string, std::string> config_;
|
||||
std::shared_ptr<FUNASR_HANDLE> asr_handler_;
|
||||
};
|
||||
|
||||
@ -1,12 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
rm cmake -rf
|
||||
mkdir -p cmake/build
|
||||
|
||||
cd cmake/build
|
||||
|
||||
cmake -DCMAKE_BUILD_TYPE=release ../.. -DONNXRUNTIME_DIR=/data/asrmodel/onnxruntime-linux-x64-1.14.0
|
||||
make
|
||||
|
||||
|
||||
echo "Build cmake/build/paraformer_server successfully!"
|
||||
12
funasr/runtime/grpc/run_server.sh
Executable file
12
funasr/runtime/grpc/run_server.sh
Executable file
@ -0,0 +1,12 @@
|
||||
#!/bin/bash
|
||||
|
||||
./build/bin/paraformer-server \
|
||||
--port-id 10100 \
|
||||
--offline-model-dir funasr_models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx \
|
||||
--online-model-dir funasr_models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online \
|
||||
--quantize true \
|
||||
--vad-dir funasr_models/damo/speech_fsmn_vad_zh-cn-16k-common-onnx \
|
||||
--vad-quant true \
|
||||
--punc-dir funasr_models/damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727 \
|
||||
--punc-quant true \
|
||||
2>&1
|
||||
@ -9,6 +9,9 @@ target_link_libraries(funasr-onnx-offline-vad PUBLIC funasr)
|
||||
add_executable(funasr-onnx-online-vad "funasr-onnx-online-vad.cpp")
|
||||
target_link_libraries(funasr-onnx-online-vad PUBLIC funasr)
|
||||
|
||||
add_executable(funasr-onnx-online-asr "funasr-onnx-online-asr.cpp")
|
||||
target_link_libraries(funasr-onnx-online-asr PUBLIC funasr)
|
||||
|
||||
add_executable(funasr-onnx-offline-punc "funasr-onnx-offline-punc.cpp")
|
||||
target_link_libraries(funasr-onnx-offline-punc PUBLIC funasr)
|
||||
|
||||
@ -17,3 +20,16 @@ target_link_libraries(funasr-onnx-online-punc PUBLIC funasr)
|
||||
|
||||
add_executable(funasr-onnx-offline-rtf "funasr-onnx-offline-rtf.cpp")
|
||||
target_link_libraries(funasr-onnx-offline-rtf PUBLIC funasr)
|
||||
|
||||
add_executable(funasr-onnx-2pass "funasr-onnx-2pass.cpp")
|
||||
target_link_libraries(funasr-onnx-2pass PUBLIC funasr)
|
||||
|
||||
add_executable(funasr-onnx-2pass-rtf "funasr-onnx-2pass-rtf.cpp")
|
||||
target_link_libraries(funasr-onnx-2pass-rtf PUBLIC funasr)
|
||||
|
||||
add_executable(funasr-onnx-online-rtf "funasr-onnx-online-rtf.cpp")
|
||||
target_link_libraries(funasr-onnx-online-rtf PUBLIC funasr)
|
||||
|
||||
# include_directories(${FFMPEG_DIR}/include)
|
||||
# add_executable(ff "ffmpeg.cpp")
|
||||
# target_link_libraries(ff PUBLIC avutil avcodec avformat swresample)
|
||||
|
||||
310
funasr/runtime/onnxruntime/bin/funasr-onnx-2pass-rtf.cpp
Normal file
310
funasr/runtime/onnxruntime/bin/funasr-onnx-2pass-rtf.cpp
Normal file
@ -0,0 +1,310 @@
|
||||
/**
|
||||
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
|
||||
* MIT License (https://opensource.org/licenses/MIT)
|
||||
*/
|
||||
|
||||
#ifndef _WIN32
|
||||
#include <sys/time.h>
|
||||
#else
|
||||
#include <win_func.h>
|
||||
#endif
|
||||
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <map>
|
||||
#include <atomic>
|
||||
#include <mutex>
|
||||
#include <thread>
|
||||
#include <glog/logging.h>
|
||||
#include "funasrruntime.h"
|
||||
#include "tclap/CmdLine.h"
|
||||
#include "com-define.h"
|
||||
#include "audio.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
std::atomic<int> wav_index(0);
|
||||
std::mutex mtx;
|
||||
|
||||
// Returns true when `filename` has the extension `target` (text after the
// last '.', compared case-sensitively). A name with no '.' never matches.
bool is_target_file(const std::string& filename, const std::string target) {
    const std::size_t dot = filename.rfind('.');
    if (dot == std::string::npos) {
        return false;
    }
    // Compare the tail after the dot directly against the target extension.
    return filename.compare(dot + 1, std::string::npos, target) == 0;
}
|
||||
|
||||
// Stores a parsed TCLAP string option into `model_path` under `key` and
// logs the pair. std::map::insert leaves an existing entry untouched.
void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key, std::map<std::string, std::string>& model_path)
{
    const std::string parsed = value_arg.getValue();
    model_path.insert({key, parsed});
    LOG(INFO) << key << " : " << parsed;
}
|
||||
|
||||
|
||||
// Worker-thread body for the 2-pass RTF benchmark.
//
// Each thread builds its own online (streaming) handle from the shared
// `tpass_handle`, warms the engine up on wav_list[0], then pulls wav
// indices from the global atomic `wav_index` until the list is exhausted.
// Per-chunk inference wall time is accumulated into *total_time and the
// decoded audio duration into *total_length (both under `mtx`).
//
// tpass_handle: shared engine handle (created by the caller, not freed here).
// chunk_size:   streaming chunk configuration forwarded to FunTpassOnlineInit.
// wav_list/wav_ids: parallel arrays of audio paths and their ids (copied
//                   per thread by value).
// total_length: [out] sum of decoded audio seconds across threads.
// total_time:   [out] max of per-thread inference micros (see merge below).
// core_id:      thread ordinal; currently unused in the body.
// asr_mode_:    0 = offline, 1 = online, 2 = 2pass (drives result logging).
void runReg(FUNASR_HANDLE tpass_handle, std::vector<int> chunk_size, vector<string> wav_list, vector<string> wav_ids,
            float* total_length, long* total_time, int core_id, ASR_TYPE asr_mode_) {

    struct timeval start, end;
    long seconds = 0;
    float n_total_length = 0.0f;  // audio seconds decoded by this thread
    long n_total_time = 0;        // inference micros spent by this thread

    // init online features: per-thread streaming state over the shared engine
    FUNASR_HANDLE tpass_online_handle=FunTpassOnlineInit(tpass_handle, chunk_size);

    // warm up: run wav_list[0] twice untimed so lazy initialization and
    // caches do not distort the measured RTF below
    for (size_t i = 0; i < 2; i++)
    {
        int32_t sampling_rate_ = 16000;
        funasr::Audio audio(1);
        // Pick the loader by extension; anything that is not wav/pcm is
        // handed to ffmpeg for decoding.
        if(is_target_file(wav_list[0].c_str(), "wav")){
            if(!audio.LoadWav2Char(wav_list[0].c_str(), &sampling_rate_)){
                LOG(ERROR)<<"Failed to load "<< wav_list[0];
                exit(-1);
            }
        }else if(is_target_file(wav_list[0].c_str(), "pcm")){
            if (!audio.LoadPcmwav2Char(wav_list[0].c_str(), &sampling_rate_)){
                LOG(ERROR)<<"Failed to load "<< wav_list[0];
                exit(-1);
            }
        }else{
            if (!audio.FfmpegLoad(wav_list[0].c_str(), true)){
                LOG(ERROR)<<"Failed to load "<< wav_list[0];
                exit(-1);
            }
        }
        char* speech_buff = audio.GetSpeechChar();
        // GetSpeechLen() is a sample count; *2 converts to bytes
        // (16-bit PCM) since the buffer is walked in bytes below.
        int buff_len = audio.GetSpeechLen()*2;

        int step = 1600*2;  // 1600 samples (100 ms @ 16 kHz) per chunk, in bytes
        bool is_final = false;

        std::vector<std::vector<string>> punc_cache(2);
        for (int sample_offset = 0; sample_offset < buff_len; sample_offset += std::min(step, buff_len - sample_offset)) {
            // Last (possibly short) chunk: shrink step and flag is_final so
            // the engine flushes its streaming state.
            if (sample_offset + step >= buff_len - 1) {
                step = buff_len - sample_offset;
                is_final = true;
            } else {
                is_final = false;
            }
            FUNASR_RESULT result = FunTpassInferBuffer(tpass_handle, tpass_online_handle, speech_buff+sample_offset, step, punc_cache, is_final, sampling_rate_, "pcm", (ASR_TYPE)asr_mode_);
            if (result)
            {
                // Warm-up: results are discarded, only freed.
                FunASRFreeResult(result);
            }
        }
    }

    while (true) {
        // Claim the next wav via the shared atomic index (translated from
        // the original Chinese comment).
        int i = wav_index.fetch_add(1);
        if (i >= wav_list.size()) {
            break;
        }
        int32_t sampling_rate_ = 16000;
        funasr::Audio audio(1);
        if(is_target_file(wav_list[i].c_str(), "wav")){
            if(!audio.LoadWav2Char(wav_list[i].c_str(), &sampling_rate_)){
                LOG(ERROR)<<"Failed to load "<< wav_list[i];
                exit(-1);
            }
        }else if(is_target_file(wav_list[i].c_str(), "pcm")){
            if (!audio.LoadPcmwav2Char(wav_list[i].c_str(), &sampling_rate_)){
                LOG(ERROR)<<"Failed to load "<< wav_list[i];
                exit(-1);
            }
        }else{
            if (!audio.FfmpegLoad(wav_list[i].c_str(), true)){
                LOG(ERROR)<<"Failed to load "<< wav_list[i];
                exit(-1);
            }
        }
        char* speech_buff = audio.GetSpeechChar();
        int buff_len = audio.GetSpeechLen()*2;  // bytes of 16-bit PCM

        int step = 1600*2;
        bool is_final = false;

        string online_res="";  // concatenated streaming (first-pass) text
        string tpass_res="";   // concatenated offline (second-pass) text
        std::vector<std::vector<string>> punc_cache(2);
        for (int sample_offset = 0; sample_offset < buff_len; sample_offset += std::min(step, buff_len - sample_offset)) {
            if (sample_offset + step >= buff_len - 1) {
                step = buff_len - sample_offset;
                is_final = true;
            } else {
                is_final = false;
            }
            // Time only the inference call itself.
            gettimeofday(&start, NULL);
            FUNASR_RESULT result = FunTpassInferBuffer(tpass_handle, tpass_online_handle, speech_buff+sample_offset, step, punc_cache, is_final, sampling_rate_, "pcm", (ASR_TYPE)asr_mode_);
            gettimeofday(&end, NULL);
            seconds = (end.tv_sec - start.tv_sec);
            long taking_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
            n_total_time += taking_micros;

            if (result)
            {
                string online_msg = FunASRGetResult(result, 0);
                online_res += online_msg;
                if(online_msg != ""){
                    LOG(INFO)<< wav_ids[i] <<" : "<<online_msg;
                }
                string tpass_msg = FunASRGetTpassResult(result, 0);
                tpass_res += tpass_msg;
                if(tpass_msg != ""){
                    LOG(INFO)<< wav_ids[i] <<" offline results : "<<tpass_msg;
                }
                float snippet_time = FunASRGetRetSnippetTime(result);
                n_total_length += snippet_time;
                FunASRFreeResult(result);
            }
            else
            {
                LOG(ERROR) << ("No return data!\n");
            }
        }
        // Mode-dependent final logging. NOTE(review): for asr_mode_==1
        // (online) the "Final online results" line prints tpass_res, not
        // online_res — presumably the online text is routed through the
        // tpass result slot in that mode; confirm against FunTpassInferBuffer.
        if(asr_mode_ == 2){
            LOG(INFO) <<"Thread: " << this_thread::get_id() <<" " << wav_ids[i] << " Final online results "<<" : "<<online_res;
        }
        if(asr_mode_==1){
            LOG(INFO) <<"Thread: " << this_thread::get_id() <<" " << wav_ids[i] << " Final online results "<<" : "<<tpass_res;
        }
        if(asr_mode_ == 0 || asr_mode_==2){
            LOG(INFO) <<"Thread: " << this_thread::get_id() <<" " << wav_ids[i] << " Final offline results " <<" : "<<tpass_res;
        }

    }
    {
        // Merge this thread's stats: lengths are summed, but total_time
        // keeps the MAX over threads (threads run concurrently, so the
        // slowest thread bounds the wall time used for RTF).
        lock_guard<mutex> guard(mtx);
        *total_length += n_total_length;
        if(*total_time < n_total_time){
            *total_time = n_total_time;
        }
    }
    FunTpassOnlineUninit(tpass_online_handle);
}
|
||||
|
||||
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
FLAGS_logtostderr = true;
|
||||
|
||||
TCLAP::CmdLine cmd("funasr-onnx-2pass", ' ', "1.0");
|
||||
TCLAP::ValueArg<std::string> offline_model_dir("", OFFLINE_MODEL_DIR, "the asr offline model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string");
|
||||
TCLAP::ValueArg<std::string> online_model_dir("", ONLINE_MODEL_DIR, "the asr online model path, which contains encoder.onnx, decoder.onnx, config.yaml, am.mvn", true, "", "string");
|
||||
TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string");
|
||||
TCLAP::ValueArg<std::string> vad_dir("", VAD_DIR, "the vad online model path, which contains model.onnx, vad.yaml, vad.mvn", false, "", "string");
|
||||
TCLAP::ValueArg<std::string> vad_quant("", VAD_QUANT, "false (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "true", "string");
|
||||
TCLAP::ValueArg<std::string> punc_dir("", PUNC_DIR, "the punc online model path, which contains model.onnx, punc.yaml", false, "", "string");
|
||||
TCLAP::ValueArg<std::string> punc_quant("", PUNC_QUANT, "false (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string");
|
||||
TCLAP::ValueArg<std::string> asr_mode("", ASR_MODE, "offline, online, 2pass", false, "2pass", "string");
|
||||
TCLAP::ValueArg<std::int32_t> onnx_thread("", "onnx-inter-thread", "onnxruntime SetIntraOpNumThreads", false, 1, "int32_t");
|
||||
|
||||
TCLAP::ValueArg<std::string> wav_path("", WAV_PATH, "the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)", true, "", "string");
|
||||
|
||||
cmd.add(offline_model_dir);
|
||||
cmd.add(online_model_dir);
|
||||
cmd.add(quantize);
|
||||
cmd.add(vad_dir);
|
||||
cmd.add(vad_quant);
|
||||
cmd.add(punc_dir);
|
||||
cmd.add(punc_quant);
|
||||
cmd.add(wav_path);
|
||||
cmd.add(asr_mode);
|
||||
cmd.add(onnx_thread);
|
||||
cmd.parse(argc, argv);
|
||||
|
||||
std::map<std::string, std::string> model_path;
|
||||
GetValue(offline_model_dir, OFFLINE_MODEL_DIR, model_path);
|
||||
GetValue(online_model_dir, ONLINE_MODEL_DIR, model_path);
|
||||
GetValue(quantize, QUANTIZE, model_path);
|
||||
GetValue(vad_dir, VAD_DIR, model_path);
|
||||
GetValue(vad_quant, VAD_QUANT, model_path);
|
||||
GetValue(punc_dir, PUNC_DIR, model_path);
|
||||
GetValue(punc_quant, PUNC_QUANT, model_path);
|
||||
GetValue(wav_path, WAV_PATH, model_path);
|
||||
GetValue(asr_mode, ASR_MODE, model_path);
|
||||
|
||||
struct timeval start, end;
|
||||
gettimeofday(&start, NULL);
|
||||
int thread_num = onnx_thread.getValue();
|
||||
int asr_mode_ = -1;
|
||||
if(model_path[ASR_MODE] == "offline"){
|
||||
asr_mode_ = 0;
|
||||
}else if(model_path[ASR_MODE] == "online"){
|
||||
asr_mode_ = 1;
|
||||
}else if(model_path[ASR_MODE] == "2pass"){
|
||||
asr_mode_ = 2;
|
||||
}else{
|
||||
LOG(ERROR) << "Wrong asr-mode : " << model_path[ASR_MODE];
|
||||
exit(-1);
|
||||
}
|
||||
FUNASR_HANDLE tpass_hanlde=FunTpassInit(model_path, thread_num);
|
||||
|
||||
if (!tpass_hanlde)
|
||||
{
|
||||
LOG(ERROR) << "FunTpassInit init failed";
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
gettimeofday(&end, NULL);
|
||||
long seconds = (end.tv_sec - start.tv_sec);
|
||||
long modle_init_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
|
||||
LOG(INFO) << "Model initialization takes " << (double)modle_init_micros / 1000000 << " s";
|
||||
|
||||
// read wav_path
|
||||
vector<string> wav_list;
|
||||
vector<string> wav_ids;
|
||||
string default_id = "wav_default_id";
|
||||
string wav_path_ = model_path.at(WAV_PATH);
|
||||
|
||||
if(is_target_file(wav_path_, "scp")){
|
||||
ifstream in(wav_path_);
|
||||
if (!in.is_open()) {
|
||||
LOG(ERROR) << "Failed to open file: " << model_path.at(WAV_SCP) ;
|
||||
return 0;
|
||||
}
|
||||
string line;
|
||||
while(getline(in, line))
|
||||
{
|
||||
istringstream iss(line);
|
||||
string column1, column2;
|
||||
iss >> column1 >> column2;
|
||||
wav_list.emplace_back(column2);
|
||||
wav_ids.emplace_back(column1);
|
||||
}
|
||||
in.close();
|
||||
}else{
|
||||
wav_list.emplace_back(wav_path_);
|
||||
wav_ids.emplace_back(default_id);
|
||||
}
|
||||
|
||||
std::vector<int> chunk_size = {5,10,5};
|
||||
// 多线程测试
|
||||
float total_length = 0.0f;
|
||||
long total_time = 0;
|
||||
std::vector<std::thread> threads;
|
||||
|
||||
int rtf_threds = 5;
|
||||
for (int i = 0; i < rtf_threds; i++)
|
||||
{
|
||||
threads.emplace_back(thread(runReg, tpass_hanlde, chunk_size, wav_list, wav_ids, &total_length, &total_time, i, (ASR_TYPE)asr_mode_));
|
||||
}
|
||||
|
||||
for (auto& thread : threads)
|
||||
{
|
||||
thread.join();
|
||||
}
|
||||
|
||||
LOG(INFO) << "total_time_wav " << (long)(total_length * 1000) << " ms";
|
||||
LOG(INFO) << "total_time_comput " << total_time / 1000 << " ms";
|
||||
LOG(INFO) << "total_rtf " << (double)total_time/ (total_length*1000000);
|
||||
LOG(INFO) << "speedup " << 1.0/((double)total_time/ (total_length*1000000));
|
||||
|
||||
FunTpassUninit(tpass_hanlde);
|
||||
return 0;
|
||||
}
|
||||
|
||||
217
funasr/runtime/onnxruntime/bin/funasr-onnx-2pass.cpp
Normal file
217
funasr/runtime/onnxruntime/bin/funasr-onnx-2pass.cpp
Normal file
@ -0,0 +1,217 @@
|
||||
/**
|
||||
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
|
||||
* MIT License (https://opensource.org/licenses/MIT)
|
||||
*/
|
||||
|
||||
#ifndef _WIN32
|
||||
#include <sys/time.h>
|
||||
#else
|
||||
#include <win_func.h>
|
||||
#endif
|
||||
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <map>
|
||||
#include <glog/logging.h>
|
||||
#include "funasrruntime.h"
|
||||
#include "tclap/CmdLine.h"
|
||||
#include "com-define.h"
|
||||
#include "audio.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
// True iff the text after the last '.' in `filename` equals `target`
// exactly (case-sensitive). Names without a dot have no extension.
bool is_target_file(const std::string& filename, const std::string target) {
    const auto sep = filename.find_last_of('.');
    return sep != std::string::npos && filename.substr(sep + 1) == target;
}
|
||||
|
||||
// Copies a parsed TCLAP string option into `model_path` under `key` and
// logs the pair. Note: std::map::insert does not overwrite an existing
// entry with the same key.
void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key, std::map<std::string, std::string>& model_path)
{
    model_path.insert({key, value_arg.getValue()});
    LOG(INFO)<< key << " : " << value_arg.getValue();
}
|
||||
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
FLAGS_logtostderr = true;
|
||||
|
||||
TCLAP::CmdLine cmd("funasr-onnx-2pass", ' ', "1.0");
|
||||
TCLAP::ValueArg<std::string> offline_model_dir("", OFFLINE_MODEL_DIR, "the asr offline model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string");
|
||||
TCLAP::ValueArg<std::string> online_model_dir("", ONLINE_MODEL_DIR, "the asr online model path, which contains encoder.onnx, decoder.onnx, config.yaml, am.mvn", true, "", "string");
|
||||
TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string");
|
||||
TCLAP::ValueArg<std::string> vad_dir("", VAD_DIR, "the vad online model path, which contains model.onnx, vad.yaml, vad.mvn", false, "", "string");
|
||||
TCLAP::ValueArg<std::string> vad_quant("", VAD_QUANT, "false (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "true", "string");
|
||||
TCLAP::ValueArg<std::string> punc_dir("", PUNC_DIR, "the punc online model path, which contains model.onnx, punc.yaml", false, "", "string");
|
||||
TCLAP::ValueArg<std::string> punc_quant("", PUNC_QUANT, "false (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string");
|
||||
TCLAP::ValueArg<std::string> asr_mode("", ASR_MODE, "offline, online, 2pass", false, "2pass", "string");
|
||||
TCLAP::ValueArg<std::int32_t> onnx_thread("", "onnx-inter-thread", "onnxruntime SetIntraOpNumThreads", false, 1, "int32_t");
|
||||
|
||||
TCLAP::ValueArg<std::string> wav_path("", WAV_PATH, "the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)", true, "", "string");
|
||||
|
||||
cmd.add(offline_model_dir);
|
||||
cmd.add(online_model_dir);
|
||||
cmd.add(quantize);
|
||||
cmd.add(vad_dir);
|
||||
cmd.add(vad_quant);
|
||||
cmd.add(punc_dir);
|
||||
cmd.add(punc_quant);
|
||||
cmd.add(wav_path);
|
||||
cmd.add(asr_mode);
|
||||
cmd.add(onnx_thread);
|
||||
cmd.parse(argc, argv);
|
||||
|
||||
std::map<std::string, std::string> model_path;
|
||||
GetValue(offline_model_dir, OFFLINE_MODEL_DIR, model_path);
|
||||
GetValue(online_model_dir, ONLINE_MODEL_DIR, model_path);
|
||||
GetValue(quantize, QUANTIZE, model_path);
|
||||
GetValue(vad_dir, VAD_DIR, model_path);
|
||||
GetValue(vad_quant, VAD_QUANT, model_path);
|
||||
GetValue(punc_dir, PUNC_DIR, model_path);
|
||||
GetValue(punc_quant, PUNC_QUANT, model_path);
|
||||
GetValue(wav_path, WAV_PATH, model_path);
|
||||
GetValue(asr_mode, ASR_MODE, model_path);
|
||||
|
||||
struct timeval start, end;
|
||||
gettimeofday(&start, NULL);
|
||||
int thread_num = onnx_thread.getValue();
|
||||
int asr_mode_ = -1;
|
||||
if(model_path[ASR_MODE] == "offline"){
|
||||
asr_mode_ = 0;
|
||||
}else if(model_path[ASR_MODE] == "online"){
|
||||
asr_mode_ = 1;
|
||||
}else if(model_path[ASR_MODE] == "2pass"){
|
||||
asr_mode_ = 2;
|
||||
}else{
|
||||
LOG(ERROR) << "Wrong asr-mode : " << model_path[ASR_MODE];
|
||||
exit(-1);
|
||||
}
|
||||
FUNASR_HANDLE tpass_handle=FunTpassInit(model_path, thread_num);
|
||||
|
||||
if (!tpass_handle)
|
||||
{
|
||||
LOG(ERROR) << "FunTpassInit init failed";
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
gettimeofday(&end, NULL);
|
||||
long seconds = (end.tv_sec - start.tv_sec);
|
||||
long modle_init_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
|
||||
LOG(INFO) << "Model initialization takes " << (double)modle_init_micros / 1000000 << " s";
|
||||
|
||||
// read wav_path
|
||||
vector<string> wav_list;
|
||||
vector<string> wav_ids;
|
||||
string default_id = "wav_default_id";
|
||||
string wav_path_ = model_path.at(WAV_PATH);
|
||||
|
||||
if(is_target_file(wav_path_, "scp")){
|
||||
ifstream in(wav_path_);
|
||||
if (!in.is_open()) {
|
||||
LOG(ERROR) << "Failed to open file: " << model_path.at(WAV_SCP) ;
|
||||
return 0;
|
||||
}
|
||||
string line;
|
||||
while(getline(in, line))
|
||||
{
|
||||
istringstream iss(line);
|
||||
string column1, column2;
|
||||
iss >> column1 >> column2;
|
||||
wav_list.emplace_back(column2);
|
||||
wav_ids.emplace_back(column1);
|
||||
}
|
||||
in.close();
|
||||
}else{
|
||||
wav_list.emplace_back(wav_path_);
|
||||
wav_ids.emplace_back(default_id);
|
||||
}
|
||||
|
||||
// init online features
|
||||
std::vector<int> chunk_size = {5,10,5};
|
||||
FUNASR_HANDLE tpass_online_handle=FunTpassOnlineInit(tpass_handle, chunk_size);
|
||||
float snippet_time = 0.0f;
|
||||
long taking_micros = 0;
|
||||
for (int i = 0; i < wav_list.size(); i++) {
|
||||
auto& wav_file = wav_list[i];
|
||||
auto& wav_id = wav_ids[i];
|
||||
|
||||
int32_t sampling_rate_ = 16000;
|
||||
funasr::Audio audio(1);
|
||||
if(is_target_file(wav_file.c_str(), "wav")){
|
||||
if(!audio.LoadWav2Char(wav_file.c_str(), &sampling_rate_)){
|
||||
LOG(ERROR)<<"Failed to load "<< wav_file;
|
||||
exit(-1);
|
||||
}
|
||||
}else if(is_target_file(wav_file.c_str(), "pcm")){
|
||||
if (!audio.LoadPcmwav2Char(wav_file.c_str(), &sampling_rate_)){
|
||||
LOG(ERROR)<<"Failed to load "<< wav_file;
|
||||
exit(-1);
|
||||
}
|
||||
}else{
|
||||
if (!audio.FfmpegLoad(wav_file.c_str(), true)){
|
||||
LOG(ERROR)<<"Failed to load "<< wav_file;
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
char* speech_buff = audio.GetSpeechChar();
|
||||
int buff_len = audio.GetSpeechLen()*2;
|
||||
|
||||
int step = 1600*2;
|
||||
bool is_final = false;
|
||||
|
||||
string online_res="";
|
||||
string tpass_res="";
|
||||
std::vector<std::vector<string>> punc_cache(2);
|
||||
for (int sample_offset = 0; sample_offset < buff_len; sample_offset += std::min(step, buff_len - sample_offset)) {
|
||||
if (sample_offset + step >= buff_len - 1) {
|
||||
step = buff_len - sample_offset;
|
||||
is_final = true;
|
||||
} else {
|
||||
is_final = false;
|
||||
}
|
||||
gettimeofday(&start, NULL);
|
||||
FUNASR_RESULT result = FunTpassInferBuffer(tpass_handle, tpass_online_handle, speech_buff+sample_offset, step, punc_cache, is_final, sampling_rate_, "pcm", (ASR_TYPE)asr_mode_);
|
||||
gettimeofday(&end, NULL);
|
||||
seconds = (end.tv_sec - start.tv_sec);
|
||||
taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
|
||||
|
||||
if (result)
|
||||
{
|
||||
string online_msg = FunASRGetResult(result, 0);
|
||||
online_res += online_msg;
|
||||
if(online_msg != ""){
|
||||
LOG(INFO)<< wav_id <<" : "<<online_msg;
|
||||
}
|
||||
string tpass_msg = FunASRGetTpassResult(result, 0);
|
||||
tpass_res += tpass_msg;
|
||||
if(tpass_msg != ""){
|
||||
LOG(INFO)<< wav_id <<" offline results : "<<tpass_msg;
|
||||
}
|
||||
snippet_time += FunASRGetRetSnippetTime(result);
|
||||
FunASRFreeResult(result);
|
||||
}
|
||||
}
|
||||
if(asr_mode_==2){
|
||||
LOG(INFO) << wav_id << " Final online results "<<" : "<<online_res;
|
||||
}
|
||||
if(asr_mode_==1){
|
||||
LOG(INFO) << wav_id << " Final online results "<<" : "<<tpass_res;
|
||||
}
|
||||
if(asr_mode_==0 || asr_mode_==2){
|
||||
LOG(INFO) << wav_id << " Final offline results " <<" : "<<tpass_res;
|
||||
}
|
||||
}
|
||||
|
||||
LOG(INFO) << "Audio length: " << (double)snippet_time << " s";
|
||||
LOG(INFO) << "Model inference takes: " << (double)taking_micros / 1000000 <<" s";
|
||||
LOG(INFO) << "Model inference RTF: " << (double)taking_micros/ (snippet_time*1000000);
|
||||
FunTpassOnlineUninit(tpass_online_handle);
|
||||
FunTpassUninit(tpass_handle);
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -40,6 +40,9 @@ void runReg(FUNASR_HANDLE asr_handle, vector<string> wav_list, vector<string> wa
|
||||
for (size_t i = 0; i < 1; i++)
|
||||
{
|
||||
FUNASR_RESULT result=FunOfflineInfer(asr_handle, wav_list[0].c_str(), RASR_NONE, NULL, 16000);
|
||||
if(result){
|
||||
FunASRFreeResult(result);
|
||||
}
|
||||
}
|
||||
|
||||
while (true) {
|
||||
|
||||
174
funasr/runtime/onnxruntime/bin/funasr-onnx-online-asr.cpp
Normal file
174
funasr/runtime/onnxruntime/bin/funasr-onnx-online-asr.cpp
Normal file
@ -0,0 +1,174 @@
|
||||
/**
|
||||
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
|
||||
* MIT License (https://opensource.org/licenses/MIT)
|
||||
*/
|
||||
|
||||
#ifndef _WIN32
|
||||
#include <sys/time.h>
|
||||
#else
|
||||
#include <win_func.h>
|
||||
#endif
|
||||
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <glog/logging.h>
|
||||
#include "funasrruntime.h"
|
||||
#include "tclap/CmdLine.h"
|
||||
#include "com-define.h"
|
||||
#include "audio.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
// Checks whether `filename` carries the extension `target` — i.e. the
// substring after the final '.' equals it exactly. No dot, no match.
bool is_target_file(const std::string& filename, const std::string target) {
    const std::size_t dot_pos = filename.find_last_of('.');
    if (dot_pos == std::string::npos) {
        return false;
    }
    const std::string ext(filename.begin() + dot_pos + 1, filename.end());
    return ext == target;
}
|
||||
|
||||
// Copies a parsed TCLAP string option into `model_path` under `key` and
// logs the pair — but only when the option was explicitly set on the
// command line (unlike the variant in the 2-pass tools, which always
// inserts the default value too).
void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key, std::map<std::string, std::string>& model_path)
{
    if (value_arg.isSet()){
        model_path.insert({key, value_arg.getValue()});
        LOG(INFO)<< key << " : " << value_arg.getValue();
    }
}
|
||||
|
||||
int main(int argc, char *argv[])
|
||||
{
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
FLAGS_logtostderr = true;
|
||||
|
||||
TCLAP::CmdLine cmd("funasr-onnx-offline-vad", ' ', "1.0");
|
||||
TCLAP::ValueArg<std::string> model_dir("", MODEL_DIR, "the vad model path, which contains model.onnx, vad.yaml, vad.mvn", true, "", "string");
|
||||
TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string");
|
||||
|
||||
TCLAP::ValueArg<std::string> wav_path("", WAV_PATH, "the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)", true, "", "string");
|
||||
|
||||
cmd.add(model_dir);
|
||||
cmd.add(quantize);
|
||||
cmd.add(wav_path);
|
||||
cmd.parse(argc, argv);
|
||||
|
||||
std::map<std::string, std::string> model_path;
|
||||
GetValue(model_dir, MODEL_DIR, model_path);
|
||||
GetValue(quantize, QUANTIZE, model_path);
|
||||
GetValue(wav_path, WAV_PATH, model_path);
|
||||
|
||||
struct timeval start, end;
|
||||
gettimeofday(&start, NULL);
|
||||
int thread_num = 1;
|
||||
FUNASR_HANDLE asr_handle=FunASRInit(model_path, thread_num, ASR_ONLINE);
|
||||
|
||||
if (!asr_handle)
|
||||
{
|
||||
LOG(ERROR) << "FunVad init failed";
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
gettimeofday(&end, NULL);
|
||||
long seconds = (end.tv_sec - start.tv_sec);
|
||||
long modle_init_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
|
||||
LOG(INFO) << "Model initialization takes " << (double)modle_init_micros / 1000000 << " s";
|
||||
|
||||
// read wav_path
|
||||
vector<string> wav_list;
|
||||
vector<string> wav_ids;
|
||||
string default_id = "wav_default_id";
|
||||
string wav_path_ = model_path.at(WAV_PATH);
|
||||
if(is_target_file(wav_path_, "scp")){
|
||||
ifstream in(wav_path_);
|
||||
if (!in.is_open()) {
|
||||
LOG(ERROR) << "Failed to open file: " << model_path.at(WAV_SCP) ;
|
||||
return 0;
|
||||
}
|
||||
string line;
|
||||
while(getline(in, line))
|
||||
{
|
||||
istringstream iss(line);
|
||||
string column1, column2;
|
||||
iss >> column1 >> column2;
|
||||
wav_list.emplace_back(column2);
|
||||
wav_ids.emplace_back(column1);
|
||||
}
|
||||
in.close();
|
||||
}else{
|
||||
wav_list.emplace_back(wav_path_);
|
||||
wav_ids.emplace_back(default_id);
|
||||
}
|
||||
|
||||
// init online features
|
||||
FUNASR_HANDLE online_handle=FunASROnlineInit(asr_handle);
|
||||
float snippet_time = 0.0f;
|
||||
long taking_micros = 0;
|
||||
for (int i = 0; i < wav_list.size(); i++) {
|
||||
auto& wav_file = wav_list[i];
|
||||
auto& wav_id = wav_ids[i];
|
||||
|
||||
int32_t sampling_rate_ = -1;
|
||||
funasr::Audio audio(1);
|
||||
if(is_target_file(wav_file.c_str(), "wav")){
|
||||
if(!audio.LoadWav2Char(wav_file.c_str(), &sampling_rate_)){
|
||||
LOG(ERROR)<<"Failed to load "<< wav_file;
|
||||
exit(-1);
|
||||
}
|
||||
}else if(is_target_file(wav_file.c_str(), "pcm")){
|
||||
if (!audio.LoadPcmwav2Char(wav_file.c_str(), &sampling_rate_)){
|
||||
LOG(ERROR)<<"Failed to load "<< wav_file;
|
||||
exit(-1);
|
||||
}
|
||||
}else{
|
||||
if (!audio.FfmpegLoad(wav_file.c_str(), true)){
|
||||
LOG(ERROR)<<"Failed to load "<< wav_file;
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
char* speech_buff = audio.GetSpeechChar();
|
||||
int buff_len = audio.GetSpeechLen()*2;
|
||||
|
||||
int step = 9600*2;
|
||||
bool is_final = false;
|
||||
|
||||
string final_res="";
|
||||
for (int sample_offset = 0; sample_offset < buff_len; sample_offset += std::min(step, buff_len - sample_offset)) {
|
||||
if (sample_offset + step >= buff_len - 1) {
|
||||
step = buff_len - sample_offset;
|
||||
is_final = true;
|
||||
} else {
|
||||
is_final = false;
|
||||
}
|
||||
gettimeofday(&start, NULL);
|
||||
FUNASR_RESULT result = FunASRInferBuffer(online_handle, speech_buff+sample_offset, step, RASR_NONE, NULL, is_final, 16000);
|
||||
gettimeofday(&end, NULL);
|
||||
seconds = (end.tv_sec - start.tv_sec);
|
||||
taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
|
||||
|
||||
if (result)
|
||||
{
|
||||
string msg = FunASRGetResult(result, 0);
|
||||
final_res += msg;
|
||||
LOG(INFO)<< wav_id <<" : "<<msg;
|
||||
snippet_time += FunASRGetRetSnippetTime(result);
|
||||
FunASRFreeResult(result);
|
||||
}
|
||||
else
|
||||
{
|
||||
LOG(ERROR) << ("No return data!\n");
|
||||
}
|
||||
}
|
||||
LOG(INFO)<<"Final results " << wav_id <<" : "<<final_res;
|
||||
}
|
||||
|
||||
LOG(INFO) << "Audio length: " << (double)snippet_time << " s";
|
||||
LOG(INFO) << "Model inference takes: " << (double)taking_micros / 1000000 <<" s";
|
||||
LOG(INFO) << "Model inference RTF: " << (double)taking_micros/ (snippet_time*1000000);
|
||||
FunASRUninit(asr_handle);
|
||||
FunASRUninit(online_handle);
|
||||
return 0;
|
||||
}
|
||||
|
||||
278
funasr/runtime/onnxruntime/bin/funasr-onnx-online-rtf.cpp
Normal file
278
funasr/runtime/onnxruntime/bin/funasr-onnx-online-rtf.cpp
Normal file
@ -0,0 +1,278 @@
|
||||
/**
|
||||
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
|
||||
* MIT License (https://opensource.org/licenses/MIT)
|
||||
*/
|
||||
|
||||
#ifndef _WIN32
|
||||
#include <sys/time.h>
|
||||
#else
|
||||
#include <win_func.h>
|
||||
#endif
|
||||
|
||||
#include <glog/logging.h>
|
||||
#include "funasrruntime.h"
|
||||
#include "tclap/CmdLine.h"
|
||||
#include "com-define.h"
|
||||
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
#include <atomic>
|
||||
#include <mutex>
|
||||
#include <thread>
|
||||
#include <map>
|
||||
#include "audio.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
std::atomic<int> wav_index(0);
|
||||
std::mutex mtx;
|
||||
|
||||
// Returns whether `filename`'s extension (text following the last '.')
// matches `target` exactly. Dotless names never match.
bool is_target_file(const std::string& filename, const std::string target) {
    const std::size_t dot = filename.rfind('.');
    if (dot == std::string::npos) {
        return false;
    }
    return filename.compare(dot + 1, std::string::npos, target) == 0;
}
|
||||
|
||||
void runReg(FUNASR_HANDLE asr_handle, vector<string> wav_list, vector<string> wav_ids,
|
||||
float* total_length, long* total_time, int core_id) {
|
||||
|
||||
struct timeval start, end;
|
||||
long seconds = 0;
|
||||
float n_total_length = 0.0f;
|
||||
long n_total_time = 0;
|
||||
|
||||
// init online features
|
||||
FUNASR_HANDLE online_handle=FunASROnlineInit(asr_handle);
|
||||
|
||||
// warm up
|
||||
for (size_t i = 0; i < 10; i++)
|
||||
{
|
||||
int32_t sampling_rate_ = -1;
|
||||
funasr::Audio audio(1);
|
||||
if(is_target_file(wav_list[0].c_str(), "wav")){
|
||||
if(!audio.LoadWav2Char(wav_list[0].c_str(), &sampling_rate_)){
|
||||
LOG(ERROR)<<"Failed to load "<< wav_list[0];
|
||||
exit(-1);
|
||||
}
|
||||
}else if(is_target_file(wav_list[0].c_str(), "pcm")){
|
||||
if (!audio.LoadPcmwav2Char(wav_list[0].c_str(), &sampling_rate_)){
|
||||
LOG(ERROR)<<"Failed to load "<< wav_list[0];
|
||||
exit(-1);
|
||||
}
|
||||
}else{
|
||||
if (!audio.FfmpegLoad(wav_list[0].c_str(), true)){
|
||||
LOG(ERROR)<<"Failed to load "<< wav_list[0];
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
char* speech_buff = audio.GetSpeechChar();
|
||||
int buff_len = audio.GetSpeechLen()*2;
|
||||
|
||||
int step = 9600*2;
|
||||
bool is_final = false;
|
||||
|
||||
string final_res="";
|
||||
for (int sample_offset = 0; sample_offset < buff_len; sample_offset += std::min(step, buff_len - sample_offset)) {
|
||||
if (sample_offset + step >= buff_len - 1) {
|
||||
step = buff_len - sample_offset;
|
||||
is_final = true;
|
||||
} else {
|
||||
is_final = false;
|
||||
}
|
||||
FUNASR_RESULT result = FunASRInferBuffer(online_handle, speech_buff+sample_offset, step, RASR_NONE, NULL, is_final, 16000);
|
||||
if (result)
|
||||
{
|
||||
FunASRFreeResult(result);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
while (true) {
|
||||
// 使用原子变量获取索引并递增
|
||||
int i = wav_index.fetch_add(1);
|
||||
if (i >= wav_list.size()) {
|
||||
break;
|
||||
}
|
||||
int32_t sampling_rate_ = -1;
|
||||
funasr::Audio audio(1);
|
||||
if(is_target_file(wav_list[i].c_str(), "wav")){
|
||||
if(!audio.LoadWav2Char(wav_list[i].c_str(), &sampling_rate_)){
|
||||
LOG(ERROR)<<"Failed to load "<< wav_list[i];
|
||||
exit(-1);
|
||||
}
|
||||
}else if(is_target_file(wav_list[i].c_str(), "pcm")){
|
||||
if (!audio.LoadPcmwav2Char(wav_list[i].c_str(), &sampling_rate_)){
|
||||
LOG(ERROR)<<"Failed to load "<< wav_list[i];
|
||||
exit(-1);
|
||||
}
|
||||
}else{
|
||||
if (!audio.FfmpegLoad(wav_list[i].c_str(), true)){
|
||||
LOG(ERROR)<<"Failed to load "<< wav_list[i];
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
char* speech_buff = audio.GetSpeechChar();
|
||||
int buff_len = audio.GetSpeechLen()*2;
|
||||
|
||||
int step = 9600*2;
|
||||
bool is_final = false;
|
||||
|
||||
string final_res="";
|
||||
for (int sample_offset = 0; sample_offset < buff_len; sample_offset += std::min(step, buff_len - sample_offset)) {
|
||||
if (sample_offset + step >= buff_len - 1) {
|
||||
step = buff_len - sample_offset;
|
||||
is_final = true;
|
||||
} else {
|
||||
is_final = false;
|
||||
}
|
||||
gettimeofday(&start, NULL);
|
||||
FUNASR_RESULT result = FunASRInferBuffer(online_handle, speech_buff+sample_offset, step, RASR_NONE, NULL, is_final, 16000);
|
||||
gettimeofday(&end, NULL);
|
||||
seconds = (end.tv_sec - start.tv_sec);
|
||||
long taking_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
|
||||
n_total_time += taking_micros;
|
||||
|
||||
if (result)
|
||||
{
|
||||
string msg = FunASRGetResult(result, 0);
|
||||
final_res += msg;
|
||||
LOG(INFO) << "Thread: " << this_thread::get_id() << "," << wav_ids[i] << " : " << msg;
|
||||
float snippet_time = FunASRGetRetSnippetTime(result);
|
||||
n_total_length += snippet_time;
|
||||
FunASRFreeResult(result);
|
||||
}
|
||||
else
|
||||
{
|
||||
LOG(ERROR) << ("No return data!\n");
|
||||
}
|
||||
}
|
||||
LOG(INFO) << "Thread: " << this_thread::get_id() << ", Final results " << wav_ids[i] << " : " << final_res;
|
||||
|
||||
}
|
||||
{
|
||||
lock_guard<mutex> guard(mtx);
|
||||
*total_length += n_total_length;
|
||||
if(*total_time < n_total_time){
|
||||
*total_time = n_total_time;
|
||||
}
|
||||
}
|
||||
FunASRUninit(online_handle);
|
||||
}
|
||||
|
||||
// Copy a command-line option into the model-path map, but only when the
// user actually supplied it; the stored value is logged for traceability.
void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key, std::map<std::string, std::string>& model_path)
{
    if (!value_arg.isSet()) {
        return;
    }
    const std::string& option_value = value_arg.getValue();
    model_path.insert({key, option_value});
    LOG(INFO) << key << " : " << option_value;
}
|
||||
|
||||
int main(int argc, char *argv[])
|
||||
{
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
FLAGS_logtostderr = true;
|
||||
|
||||
TCLAP::CmdLine cmd("funasr-onnx-online-rtf", ' ', "1.0");
|
||||
TCLAP::ValueArg<std::string> model_dir("", MODEL_DIR, "the model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string");
|
||||
TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string");
|
||||
TCLAP::ValueArg<std::string> vad_dir("", VAD_DIR, "the vad model path, which contains model.onnx, vad.yaml, vad.mvn", false, "", "string");
|
||||
TCLAP::ValueArg<std::string> vad_quant("", VAD_QUANT, "false (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "false", "string");
|
||||
TCLAP::ValueArg<std::string> punc_dir("", PUNC_DIR, "the punc model path, which contains model.onnx, punc.yaml", false, "", "string");
|
||||
TCLAP::ValueArg<std::string> punc_quant("", PUNC_QUANT, "false (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "false", "string");
|
||||
|
||||
TCLAP::ValueArg<std::string> wav_path("", WAV_PATH, "the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)", true, "", "string");
|
||||
TCLAP::ValueArg<std::int32_t> thread_num("", THREAD_NUM, "multi-thread num for rtf", true, 0, "int32_t");
|
||||
|
||||
cmd.add(model_dir);
|
||||
cmd.add(quantize);
|
||||
cmd.add(vad_dir);
|
||||
cmd.add(vad_quant);
|
||||
cmd.add(punc_dir);
|
||||
cmd.add(punc_quant);
|
||||
cmd.add(wav_path);
|
||||
cmd.add(thread_num);
|
||||
cmd.parse(argc, argv);
|
||||
|
||||
std::map<std::string, std::string> model_path;
|
||||
GetValue(model_dir, MODEL_DIR, model_path);
|
||||
GetValue(quantize, QUANTIZE, model_path);
|
||||
GetValue(vad_dir, VAD_DIR, model_path);
|
||||
GetValue(vad_quant, VAD_QUANT, model_path);
|
||||
GetValue(punc_dir, PUNC_DIR, model_path);
|
||||
GetValue(punc_quant, PUNC_QUANT, model_path);
|
||||
GetValue(wav_path, WAV_PATH, model_path);
|
||||
|
||||
struct timeval start, end;
|
||||
gettimeofday(&start, NULL);
|
||||
FUNASR_HANDLE asr_handle=FunASRInit(model_path, 1, ASR_ONLINE);
|
||||
|
||||
if (!asr_handle)
|
||||
{
|
||||
LOG(ERROR) << "FunASR init failed";
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
gettimeofday(&end, NULL);
|
||||
long seconds = (end.tv_sec - start.tv_sec);
|
||||
long modle_init_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
|
||||
LOG(INFO) << "Model initialization takes " << (double)modle_init_micros / 1000000 << " s";
|
||||
|
||||
// read wav_path
|
||||
vector<string> wav_list;
|
||||
vector<string> wav_ids;
|
||||
string default_id = "wav_default_id";
|
||||
string wav_path_ = model_path.at(WAV_PATH);
|
||||
if(is_target_file(wav_path_, "wav") || is_target_file(wav_path_, "pcm")){
|
||||
wav_list.emplace_back(wav_path_);
|
||||
wav_ids.emplace_back(default_id);
|
||||
}
|
||||
else if(is_target_file(wav_path_, "scp")){
|
||||
ifstream in(wav_path_);
|
||||
if (!in.is_open()) {
|
||||
LOG(ERROR) << "Failed to open file: " << model_path.at(WAV_SCP) ;
|
||||
return 0;
|
||||
}
|
||||
string line;
|
||||
while(getline(in, line))
|
||||
{
|
||||
istringstream iss(line);
|
||||
string column1, column2;
|
||||
iss >> column1 >> column2;
|
||||
wav_list.emplace_back(column2);
|
||||
wav_ids.emplace_back(column1);
|
||||
}
|
||||
in.close();
|
||||
}else{
|
||||
LOG(ERROR)<<"Please check the wav extension!";
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// 多线程测试
|
||||
float total_length = 0.0f;
|
||||
long total_time = 0;
|
||||
std::vector<std::thread> threads;
|
||||
|
||||
int rtf_threds = thread_num.getValue();
|
||||
for (int i = 0; i < rtf_threds; i++)
|
||||
{
|
||||
threads.emplace_back(thread(runReg, asr_handle, wav_list, wav_ids, &total_length, &total_time, i));
|
||||
}
|
||||
|
||||
for (auto& thread : threads)
|
||||
{
|
||||
thread.join();
|
||||
}
|
||||
|
||||
LOG(INFO) << "total_time_wav " << (long)(total_length * 1000) << " ms";
|
||||
LOG(INFO) << "total_time_comput " << total_time / 1000 << " ms";
|
||||
LOG(INFO) << "total_rtf " << (double)total_time/ (total_length*1000000);
|
||||
LOG(INFO) << "speedup " << 1.0/((double)total_time/ (total_length*1000000));
|
||||
|
||||
FunASRUninit(asr_handle);
|
||||
return 0;
|
||||
}
|
||||
@ -159,7 +159,7 @@ int main(int argc, char *argv[])
|
||||
char* speech_buff = audio.GetSpeechChar();
|
||||
int buff_len = audio.GetSpeechLen()*2;
|
||||
|
||||
int step = 3200;
|
||||
int step = 800*2;
|
||||
bool is_final = false;
|
||||
|
||||
for (int sample_offset = 0; sample_offset < buff_len; sample_offset += std::min(step, buff_len - sample_offset)) {
|
||||
|
||||
@ -5,6 +5,7 @@
|
||||
#include <stdint.h>
|
||||
#include "vad-model.h"
|
||||
#include "offline-stream.h"
|
||||
#include "com-define.h"
|
||||
|
||||
#ifndef WAV_HEADER_SIZE
|
||||
#define WAV_HEADER_SIZE 44
|
||||
@ -17,11 +18,13 @@ class AudioFrame {
|
||||
private:
|
||||
int start;
|
||||
int end;
|
||||
int len;
|
||||
|
||||
|
||||
public:
|
||||
AudioFrame();
|
||||
AudioFrame(int len);
|
||||
AudioFrame(const AudioFrame &other);
|
||||
AudioFrame(int start, int end, bool is_final);
|
||||
|
||||
~AudioFrame();
|
||||
int SetStart(int val);
|
||||
@ -29,6 +32,10 @@ class AudioFrame {
|
||||
int GetStart();
|
||||
int GetLen();
|
||||
int Disp();
|
||||
// 2pass
|
||||
bool is_final = false;
|
||||
float* data = nullptr;
|
||||
int len;
|
||||
};
|
||||
|
||||
class Audio {
|
||||
@ -38,10 +45,11 @@ class Audio {
|
||||
char* speech_char=nullptr;
|
||||
int speech_len;
|
||||
int speech_align_len;
|
||||
int offset;
|
||||
float align_size;
|
||||
int data_type;
|
||||
queue<AudioFrame *> frame_queue;
|
||||
queue<AudioFrame *> asr_online_queue;
|
||||
queue<AudioFrame *> asr_offline_queue;
|
||||
|
||||
public:
|
||||
Audio(int data_type);
|
||||
@ -56,17 +64,35 @@ class Audio {
|
||||
bool LoadPcmwav(const char* filename, int32_t* sampling_rate);
|
||||
bool LoadPcmwav2Char(const char* filename, int32_t* sampling_rate);
|
||||
bool LoadOthers2Char(const char* filename);
|
||||
bool FfmpegLoad(const char *filename);
|
||||
bool FfmpegLoad(const char *filename, bool copy2char=false);
|
||||
bool FfmpegLoad(const char* buf, int n_file_len);
|
||||
int FetchChunck(float *&dout, int len);
|
||||
int FetchChunck(AudioFrame *&frame);
|
||||
int FetchTpass(AudioFrame *&frame);
|
||||
int Fetch(float *&dout, int &len, int &flag);
|
||||
void Padding();
|
||||
void Split(OfflineStream* offline_streamj);
|
||||
void Split(VadModel* vad_obj, vector<std::vector<int>>& vad_segments, bool input_finished=true);
|
||||
void Split(VadModel* vad_obj, int chunk_len, bool input_finished=true, ASR_TYPE asr_mode=ASR_TWO_PASS);
|
||||
float GetTimeLen();
|
||||
int GetQueueSize() { return (int)frame_queue.size(); }
|
||||
char* GetSpeechChar(){return speech_char;}
|
||||
int GetSpeechLen(){return speech_len;}
|
||||
|
||||
// 2pass
|
||||
vector<float> all_samples;
|
||||
int offset = 0;
|
||||
int speech_start=-1, speech_end=0;
|
||||
int speech_offline_start=-1;
|
||||
|
||||
int seg_sample = MODEL_SAMPLE_RATE/1000;
|
||||
bool LoadPcmwavOnline(const char* buf, int n_file_len, int32_t* sampling_rate);
|
||||
void ResetIndex(){
|
||||
speech_start=-1;
|
||||
speech_end=0;
|
||||
speech_offline_start=-1;
|
||||
offset = 0;
|
||||
all_samples.clear();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace funasr
|
||||
|
||||
@ -13,11 +13,14 @@ namespace funasr {
|
||||
|
||||
// parser option
|
||||
#define MODEL_DIR "model-dir"
|
||||
#define OFFLINE_MODEL_DIR "model-dir"
|
||||
#define ONLINE_MODEL_DIR "online-model-dir"
|
||||
#define VAD_DIR "vad-dir"
|
||||
#define PUNC_DIR "punc-dir"
|
||||
#define QUANTIZE "quantize"
|
||||
#define VAD_QUANT "vad-quant"
|
||||
#define PUNC_QUANT "punc-quant"
|
||||
#define ASR_MODE "mode"
|
||||
|
||||
#define WAV_PATH "wav-path"
|
||||
#define WAV_SCP "wav-scp"
|
||||
@ -42,6 +45,11 @@ namespace funasr {
|
||||
#define AM_CONFIG_NAME "config.yaml"
|
||||
#define PUNC_CONFIG_NAME "punc.yaml"
|
||||
|
||||
#define ENCODER_NAME "model.onnx"
|
||||
#define QUANT_ENCODER_NAME "model_quant.onnx"
|
||||
#define DECODER_NAME "decoder.onnx"
|
||||
#define QUANT_DECODER_NAME "decoder_quant.onnx"
|
||||
|
||||
// vad
|
||||
#ifndef VAD_SILENCE_DURATION
|
||||
#define VAD_SILENCE_DURATION 800
|
||||
@ -63,6 +71,19 @@ namespace funasr {
|
||||
#define VAD_LFR_N 1
|
||||
#endif
|
||||
|
||||
// asr
|
||||
#ifndef PARA_LFR_M
|
||||
#define PARA_LFR_M 7
|
||||
#endif
|
||||
|
||||
#ifndef PARA_LFR_N
|
||||
#define PARA_LFR_N 6
|
||||
#endif
|
||||
|
||||
#ifndef ONLINE_STEP
|
||||
#define ONLINE_STEP 9600
|
||||
#endif
|
||||
|
||||
// punc
|
||||
#define UNK_CHAR "<unk>"
|
||||
#define TOKEN_LEN 20
|
||||
|
||||
@ -46,21 +46,29 @@ typedef enum {
|
||||
FUNASR_MODEL_PARAFORMER = 3,
|
||||
}FUNASR_MODEL_TYPE;
|
||||
|
||||
typedef enum {
|
||||
ASR_OFFLINE=0,
|
||||
ASR_ONLINE=1,
|
||||
ASR_TWO_PASS=2,
|
||||
}ASR_TYPE;
|
||||
|
||||
typedef enum {
|
||||
PUNC_OFFLINE=0,
|
||||
PUNC_ONLINE=1,
|
||||
}PUNC_TYPE;
|
||||
|
||||
typedef void (* QM_CALLBACK)(int cur_step, int n_total); // n_total: total steps; cur_step: Current Step.
|
||||
|
||||
|
||||
// ASR
|
||||
_FUNASRAPI FUNASR_HANDLE FunASRInit(std::map<std::string, std::string>& model_path, int thread_num);
|
||||
_FUNASRAPI FUNASR_HANDLE FunASRInit(std::map<std::string, std::string>& model_path, int thread_num, ASR_TYPE type=ASR_OFFLINE);
|
||||
_FUNASRAPI FUNASR_HANDLE FunASROnlineInit(FUNASR_HANDLE asr_handle, std::vector<int> chunk_size={5,10,5});
|
||||
// buffer
|
||||
_FUNASRAPI FUNASR_RESULT FunASRInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback, int sampling_rate=16000, std::string wav_format="pcm");
|
||||
_FUNASRAPI FUNASR_RESULT FunASRInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback, bool input_finished=true, int sampling_rate=16000, std::string wav_format="pcm");
|
||||
// file, support wav & pcm
|
||||
_FUNASRAPI FUNASR_RESULT FunASRInfer(FUNASR_HANDLE handle, const char* sz_filename, FUNASR_MODE mode, QM_CALLBACK fn_callback, int sampling_rate=16000);
|
||||
|
||||
_FUNASRAPI const char* FunASRGetResult(FUNASR_RESULT result,int n_index);
|
||||
_FUNASRAPI const char* FunASRGetTpassResult(FUNASR_RESULT result,int n_index);
|
||||
_FUNASRAPI const int FunASRGetRetNumber(FUNASR_RESULT result);
|
||||
_FUNASRAPI void FunASRFreeResult(FUNASR_RESULT result);
|
||||
_FUNASRAPI void FunASRUninit(FUNASR_HANDLE handle);
|
||||
@ -94,6 +102,14 @@ _FUNASRAPI FUNASR_RESULT FunOfflineInferBuffer(FUNASR_HANDLE handle, const char*
|
||||
_FUNASRAPI FUNASR_RESULT FunOfflineInfer(FUNASR_HANDLE handle, const char* sz_filename, FUNASR_MODE mode, QM_CALLBACK fn_callback, int sampling_rate=16000);
|
||||
_FUNASRAPI void FunOfflineUninit(FUNASR_HANDLE handle);
|
||||
|
||||
//2passStream
|
||||
_FUNASRAPI FUNASR_HANDLE FunTpassInit(std::map<std::string, std::string>& model_path, int thread_num);
|
||||
_FUNASRAPI FUNASR_HANDLE FunTpassOnlineInit(FUNASR_HANDLE tpass_handle, std::vector<int> chunk_size={5,10,5});
|
||||
// buffer
|
||||
_FUNASRAPI FUNASR_RESULT FunTpassInferBuffer(FUNASR_HANDLE handle, FUNASR_HANDLE online_handle, const char* sz_buf, int n_len, std::vector<std::vector<std::string>> &punc_cache, bool input_finished=true, int sampling_rate=16000, std::string wav_format="pcm", ASR_TYPE mode=ASR_TWO_PASS);
|
||||
_FUNASRAPI void FunTpassUninit(FUNASR_HANDLE handle);
|
||||
_FUNASRAPI void FunTpassOnlineUninit(FUNASR_HANDLE handle);
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
||||
}
|
||||
|
||||
@ -4,17 +4,21 @@
|
||||
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include "funasrruntime.h"
|
||||
namespace funasr {
|
||||
class Model {
|
||||
public:
|
||||
virtual ~Model(){};
|
||||
virtual void Reset() = 0;
|
||||
virtual void InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num)=0;
|
||||
virtual std::string ForwardChunk(float *din, int len, int flag) = 0;
|
||||
virtual std::string Forward(float *din, int len, int flag) = 0;
|
||||
virtual void InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){};
|
||||
virtual void InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){};
|
||||
virtual void InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){};
|
||||
virtual std::string Forward(float *din, int len, bool input_finished){return "";};
|
||||
virtual std::string Rescoring() = 0;
|
||||
};
|
||||
|
||||
Model *CreateModel(std::map<std::string, std::string>& model_path,int thread_num=1);
|
||||
Model *CreateModel(std::map<std::string, std::string>& model_path, int thread_num=1, ASR_TYPE type=ASR_OFFLINE);
|
||||
Model *CreateModel(void* asr_handle, std::vector<int> chunk_size);
|
||||
|
||||
} // namespace funasr
|
||||
#endif
|
||||
|
||||
@ -14,9 +14,9 @@ class OfflineStream {
|
||||
OfflineStream(std::map<std::string, std::string>& model_path, int thread_num);
|
||||
~OfflineStream(){};
|
||||
|
||||
std::unique_ptr<VadModel> vad_handle;
|
||||
std::unique_ptr<Model> asr_handle;
|
||||
std::unique_ptr<PuncModel> punc_handle;
|
||||
std::unique_ptr<VadModel> vad_handle= nullptr;
|
||||
std::unique_ptr<Model> asr_handle= nullptr;
|
||||
std::unique_ptr<PuncModel> punc_handle= nullptr;
|
||||
bool UseVad(){return use_vad;};
|
||||
bool UsePunc(){return use_punc;};
|
||||
|
||||
|
||||
20
funasr/runtime/onnxruntime/include/tpass-online-stream.h
Normal file
20
funasr/runtime/onnxruntime/include/tpass-online-stream.h
Normal file
@ -0,0 +1,20 @@
|
||||
#ifndef TPASS_ONLINE_STREAM_H
#define TPASS_ONLINE_STREAM_H

#include <memory>
#include "tpass-stream.h"
#include "model.h"
#include "vad-model.h"
// NOTE(review): std::vector is used below but <vector> is not included
// directly; this relies on a transitive include — confirm.

namespace funasr {
// Per-session state for the streaming (online) half of 2-pass decoding.
// Each client session gets its own online VAD and online ASR instances,
// created from the shared TpassStream, so sessions decode independently.
class TpassOnlineStream {
  public:
    // Builds the per-session online handles from the shared TpassStream.
    // chunk_size presumably follows the {left, middle, right} streaming
    // chunk convention used elsewhere in the API (default {5,10,5}) — verify.
    TpassOnlineStream(TpassStream* tpass_stream, std::vector<int> chunk_size);
    ~TpassOnlineStream(){};

    // Online voice-activity-detection model for this session
    // (initialized to nullptr; populated by the constructor).
    std::unique_ptr<VadModel> vad_online_handle = nullptr;
    // Streaming ASR model for this session.
    std::unique_ptr<Model> asr_online_handle = nullptr;
};
// Factory wrapper; tpass_stream is a TpassStream* carried through the
// C-style void* handle API.
TpassOnlineStream* CreateTpassOnlineStream(void* tpass_stream, std::vector<int> chunk_size);
} // namespace funasr
#endif
|
||||
31
funasr/runtime/onnxruntime/include/tpass-stream.h
Normal file
31
funasr/runtime/onnxruntime/include/tpass-stream.h
Normal file
@ -0,0 +1,31 @@
|
||||
#ifndef TPASS_STREAM_H
#define TPASS_STREAM_H

#include <memory>
#include <string>
#include <map>
#include "model.h"
#include "punc-model.h"
#include "vad-model.h"

namespace funasr {
// Shared (process-wide) model bundle for 2-pass decoding: the offline VAD,
// offline ASR and online punctuation models loaded once from model_path and
// reused by every per-session TpassOnlineStream.
class TpassStream {
  public:
    // Loads the models listed in model_path; thread_num controls the
    // inference thread count passed to the underlying runtimes.
    TpassStream(std::map<std::string, std::string>& model_path, int thread_num);
    ~TpassStream(){};

    // Offline VAD model (nullptr until loaded).
    std::unique_ptr<VadModel> vad_handle = nullptr;
    // Offline ASR model used for the second (rescoring) pass.
    std::unique_ptr<Model> asr_handle = nullptr;
    // Online punctuation model applied to streaming output.
    std::unique_ptr<PuncModel> punc_online_handle = nullptr;
    // Whether a VAD / punctuation model was configured and loaded.
    bool UseVad(){return use_vad;};
    bool UsePunc(){return use_punc;};

  private:
    bool use_vad=false;
    bool use_punc=false;
};

// Factory; returns a heap-allocated TpassStream built from model_path.
TpassStream *CreateTpassStream(std::map<std::string, std::string>& model_path, int thread_num=1);
} // namespace funasr
#endif
|
||||
@ -132,40 +132,54 @@ class AudioWindow {
|
||||
};
|
||||
};
|
||||
|
||||
AudioFrame::AudioFrame(){};
|
||||
AudioFrame::AudioFrame(){}
|
||||
AudioFrame::AudioFrame(int len) : len(len)
|
||||
{
|
||||
start = 0;
|
||||
};
|
||||
AudioFrame::~AudioFrame(){};
|
||||
}
|
||||
AudioFrame::AudioFrame(const AudioFrame &other)
|
||||
{
|
||||
start = other.start;
|
||||
end = other.end;
|
||||
len = other.len;
|
||||
is_final = other.is_final;
|
||||
}
|
||||
AudioFrame::AudioFrame(int start, int end, bool is_final):start(start),end(end),is_final(is_final){
|
||||
len = end - start;
|
||||
}
|
||||
AudioFrame::~AudioFrame(){
|
||||
if(data != NULL){
|
||||
free(data);
|
||||
}
|
||||
}
|
||||
int AudioFrame::SetStart(int val)
|
||||
{
|
||||
start = val < 0 ? 0 : val;
|
||||
return start;
|
||||
};
|
||||
}
|
||||
|
||||
int AudioFrame::SetEnd(int val)
|
||||
{
|
||||
end = val;
|
||||
len = end - start;
|
||||
return end;
|
||||
};
|
||||
}
|
||||
|
||||
int AudioFrame::GetStart()
|
||||
{
|
||||
return start;
|
||||
};
|
||||
}
|
||||
|
||||
int AudioFrame::GetLen()
|
||||
{
|
||||
return len;
|
||||
};
|
||||
}
|
||||
|
||||
int AudioFrame::Disp()
|
||||
{
|
||||
LOG(ERROR) << "Not imp!!!!";
|
||||
return 0;
|
||||
};
|
||||
}
|
||||
|
||||
Audio::Audio(int data_type) : data_type(data_type)
|
||||
{
|
||||
@ -230,7 +244,7 @@ void Audio::WavResample(int32_t sampling_rate, const float *waveform,
|
||||
copy(samples.begin(), samples.end(), speech_data);
|
||||
}
|
||||
|
||||
bool Audio::FfmpegLoad(const char *filename){
|
||||
bool Audio::FfmpegLoad(const char *filename, bool copy2char){
|
||||
// from file
|
||||
AVFormatContext* formatContext = avformat_alloc_context();
|
||||
if (avformat_open_input(&formatContext, filename, NULL, NULL) != 0) {
|
||||
@ -353,8 +367,17 @@ bool Audio::FfmpegLoad(const char *filename){
|
||||
if (speech_buff != NULL) {
|
||||
free(speech_buff);
|
||||
}
|
||||
if (speech_char != NULL) {
|
||||
free(speech_char);
|
||||
}
|
||||
offset = 0;
|
||||
|
||||
if(copy2char){
|
||||
speech_char = (char *)malloc(resampled_buffers.size());
|
||||
memset(speech_char, 0, resampled_buffers.size());
|
||||
memcpy((void*)speech_char, (const void*)resampled_buffers.data(), resampled_buffers.size());
|
||||
}
|
||||
|
||||
speech_len = (resampled_buffers.size()) / 2;
|
||||
speech_buff = (int16_t*)malloc(sizeof(int16_t) * speech_len);
|
||||
if (speech_buff)
|
||||
@ -762,6 +785,55 @@ bool Audio::LoadPcmwav(const char* buf, int n_buf_len, int32_t* sampling_rate)
|
||||
return false;
|
||||
}
|
||||
|
||||
bool Audio::LoadPcmwavOnline(const char* buf, int n_buf_len, int32_t* sampling_rate)
|
||||
{
|
||||
if (speech_data != NULL) {
|
||||
free(speech_data);
|
||||
}
|
||||
if (speech_buff != NULL) {
|
||||
free(speech_buff);
|
||||
}
|
||||
if (speech_char != NULL) {
|
||||
free(speech_char);
|
||||
}
|
||||
|
||||
speech_len = n_buf_len / 2;
|
||||
speech_buff = (int16_t*)malloc(sizeof(int16_t) * speech_len);
|
||||
if (speech_buff)
|
||||
{
|
||||
memset(speech_buff, 0, sizeof(int16_t) * speech_len);
|
||||
memcpy((void*)speech_buff, (const void*)buf, speech_len * sizeof(int16_t));
|
||||
|
||||
speech_data = (float*)malloc(sizeof(float) * speech_len);
|
||||
memset(speech_data, 0, sizeof(float) * speech_len);
|
||||
|
||||
float scale = 1;
|
||||
if (data_type == 1) {
|
||||
scale = 32768;
|
||||
}
|
||||
|
||||
for (int32_t i = 0; i != speech_len; ++i) {
|
||||
speech_data[i] = (float)speech_buff[i] / scale;
|
||||
}
|
||||
|
||||
//resample
|
||||
if(*sampling_rate != MODEL_SAMPLE_RATE){
|
||||
WavResample(*sampling_rate, speech_data, speech_len);
|
||||
}
|
||||
|
||||
for (int32_t i = 0; i != speech_len; ++i) {
|
||||
all_samples.emplace_back(speech_data[i]);
|
||||
}
|
||||
|
||||
AudioFrame* frame = new AudioFrame(speech_len);
|
||||
frame_queue.push(frame);
|
||||
return true;
|
||||
|
||||
}
|
||||
else
|
||||
return false;
|
||||
}
|
||||
|
||||
bool Audio::LoadPcmwav(const char* filename, int32_t* sampling_rate)
|
||||
{
|
||||
if (speech_data != NULL) {
|
||||
@ -870,24 +942,25 @@ bool Audio::LoadOthers2Char(const char* filename)
|
||||
return true;
|
||||
}
|
||||
|
||||
int Audio::FetchChunck(float *&dout, int len)
|
||||
int Audio::FetchTpass(AudioFrame *&frame)
|
||||
{
|
||||
if (offset >= speech_align_len) {
|
||||
dout = NULL;
|
||||
return S_ERR;
|
||||
} else if (offset == speech_align_len - len) {
|
||||
dout = speech_data + offset;
|
||||
offset = speech_align_len;
|
||||
// 临时解决
|
||||
AudioFrame *frame = frame_queue.front();
|
||||
frame_queue.pop();
|
||||
delete frame;
|
||||
|
||||
return S_END;
|
||||
if (asr_offline_queue.size() > 0) {
|
||||
frame = asr_offline_queue.front();
|
||||
asr_offline_queue.pop();
|
||||
return 1;
|
||||
} else {
|
||||
dout = speech_data + offset;
|
||||
offset += len;
|
||||
return S_MIDDLE;
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
// Pop the next queued online-ASR chunk.
// Returns 1 and sets `frame` (caller takes ownership) when a chunk was
// available, 0 when the online queue is empty.
int Audio::FetchChunck(AudioFrame *&frame)
{
    if (asr_online_queue.empty()) {
        return 0;
    }
    frame = asr_online_queue.front();
    asr_online_queue.pop();
    return 1;
}
|
||||
|
||||
@ -956,7 +1029,6 @@ void Audio::Split(OfflineStream* offline_stream)
|
||||
|
||||
std::vector<float> pcm_data(speech_data, speech_data+sp_len);
|
||||
vector<std::vector<int>> vad_segments = (offline_stream->vad_handle)->Infer(pcm_data);
|
||||
int seg_sample = MODEL_SAMPLE_RATE/1000;
|
||||
for(vector<int> segment:vad_segments)
|
||||
{
|
||||
frame = new AudioFrame();
|
||||
@ -969,7 +1041,6 @@ void Audio::Split(OfflineStream* offline_stream)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void Audio::Split(VadModel* vad_obj, vector<std::vector<int>>& vad_segments, bool input_finished)
|
||||
{
|
||||
AudioFrame *frame;
|
||||
@ -984,4 +1055,161 @@ void Audio::Split(VadModel* vad_obj, vector<std::vector<int>>& vad_segments, boo
|
||||
vad_segments = vad_obj->Infer(pcm_data, input_finished);
|
||||
}
|
||||
|
||||
// 2pass splitter: run the newest audio chunk (front of frame_queue) through
// the streaming VAD and dispatch detected speech to the online and/or
// offline ASR queues depending on asr_mode.
//
// Each VAD segment is a pair where -1 means "not known in this chunk":
//   [s, e]  -> a complete utterance inside this chunk
//   [s, -1] -> speech started but has not ended yet
//   [-1, e] -> speech that started in an earlier chunk just ended
// Segment values are multiplied by seg_sample (= MODEL_SAMPLE_RATE/1000),
// i.e. they appear to be millisecond offsets — confirm against VadModel.
// `offset` counts samples already trimmed from the front of all_samples, so
// an absolute sample position p is read at all_samples.data() + p - offset.
// chunk_len: online chunk size in samples; input_finished: last-chunk flag.
void Audio::Split(VadModel* vad_obj, int chunk_len, bool input_finished, ASR_TYPE asr_mode)
{
    AudioFrame *frame;

    // consume the bookkeeping frame for the newest chunk; its samples are in
    // speech_data (and were already appended to all_samples on load)
    frame = frame_queue.front();
    frame_queue.pop();
    int sp_len = frame->GetLen();
    delete frame;
    frame = NULL;

    std::vector<float> pcm_data(speech_data, speech_data+sp_len);
    vector<std::vector<int>> vad_segments = vad_obj->Infer(pcm_data, input_finished);

    // advance the absolute end-of-stream position
    speech_end += sp_len/seg_sample;
    if(vad_segments.size() == 0){
        // no VAD event in this chunk: if we are inside speech, keep emitting
        // fixed-size online chunks from the pending region
        if(speech_start != -1){
            int start = speech_start*seg_sample;
            int end = speech_end*seg_sample;
            int buff_len = end-start;
            int step = chunk_len;

            if(asr_mode != ASR_OFFLINE){
                if(buff_len >= step){
                    frame = new AudioFrame(step);
                    frame->data = (float*)malloc(sizeof(float) * step);
                    memcpy(frame->data, all_samples.data()+start-offset, step*sizeof(float));
                    asr_online_queue.push(frame);
                    frame = NULL;
                    speech_start += step/seg_sample;
                }
            }
        }
    }else{
        for(auto vad_segment: vad_segments){
            int speech_start_i=-1, speech_end_i=-1;
            if(vad_segment[0] != -1){
                speech_start_i = vad_segment[0];
            }
            if(vad_segment[1] != -1){
                speech_end_i = vad_segment[1];
            }

            // [1, 100]: a complete utterance within this chunk
            if(speech_start_i != -1 && speech_end_i != -1){
                int start = speech_start_i*seg_sample;
                int end = speech_end_i*seg_sample;

                // online pass receives the whole utterance as one final chunk
                if(asr_mode != ASR_OFFLINE){
                    frame = new AudioFrame(end-start);
                    frame->is_final = true;
                    frame->data = (float*)malloc(sizeof(float) * (end-start));
                    memcpy(frame->data, all_samples.data()+start-offset, (end-start)*sizeof(float));
                    asr_online_queue.push(frame);
                    frame = NULL;
                }

                // offline pass receives the same utterance for the second pass
                if(asr_mode != ASR_ONLINE){
                    frame = new AudioFrame(end-start);
                    frame->is_final = true;
                    frame->data = (float*)malloc(sizeof(float) * (end-start));
                    memcpy(frame->data, all_samples.data()+start-offset, (end-start)*sizeof(float));
                    asr_offline_queue.push(frame);
                    frame = NULL;
                }

                speech_start = -1;
                speech_offline_start = -1;
            // [70, -1]: speech opened here but has not closed yet
            }else if(speech_start_i != -1){
                speech_start = speech_start_i;
                speech_offline_start = speech_start_i;

                int start = speech_start*seg_sample;
                int end = speech_end*seg_sample;
                int buff_len = end-start;
                int step = chunk_len;

                // emit the first online chunk if enough audio is buffered;
                // speech_start advances so the sent part is not re-sent
                if(asr_mode != ASR_OFFLINE){
                    if(buff_len >= step){
                        frame = new AudioFrame(step);
                        frame->data = (float*)malloc(sizeof(float) * step);
                        memcpy(frame->data, all_samples.data()+start-offset, step*sizeof(float));
                        asr_online_queue.push(frame);
                        frame = NULL;
                        speech_start += step/seg_sample;
                    }
                }

            }else if(speech_end_i != -1){ // [-1,100]: a pending utterance ends here
                if(speech_start == -1 or speech_offline_start == -1){
                    LOG(ERROR) <<"Vad start is null while vad end is available." ;
                    exit(-1);
                }

                int start = speech_start*seg_sample;
                int offline_start = speech_offline_start*seg_sample;
                int end = speech_end_i*seg_sample;
                int buff_len = end-start;
                int step = chunk_len;

                // offline pass: the full utterance from its original start
                if(asr_mode != ASR_ONLINE){
                    frame = new AudioFrame(end-offline_start);
                    frame->is_final = true;
                    frame->data = (float*)malloc(sizeof(float) * (end-offline_start));
                    memcpy(frame->data, all_samples.data()+offline_start-offset, (end-offline_start)*sizeof(float));
                    asr_offline_queue.push(frame);
                    frame = NULL;
                }

                // online pass: flush the not-yet-sent tail in chunk_len steps,
                // marking the last chunk final; an empty tail still pushes a
                // zero-length final frame so the decoder can close the utterance
                if(asr_mode != ASR_OFFLINE){
                    if(buff_len > 0){
                        for (int sample_offset = 0; sample_offset < buff_len; sample_offset += std::min(step, buff_len - sample_offset)) {
                            bool is_final = false;
                            if (sample_offset + step >= buff_len - 1) {
                                step = buff_len - sample_offset;
                                is_final = true;
                            }
                            frame = new AudioFrame(step);
                            frame->is_final = is_final;
                            frame->data = (float*)malloc(sizeof(float) * step);
                            memcpy(frame->data, all_samples.data()+start-offset+sample_offset, step*sizeof(float));
                            asr_online_queue.push(frame);
                            frame = NULL;
                        }
                    }else{
                        frame = new AudioFrame(0);
                        frame->is_final = true;
                        asr_online_queue.push(frame);
                        frame = NULL;
                    }
                }
                speech_start = -1;
                speech_offline_start = -1;
            }
        }
    }

    // Trim all_samples so the rolling buffer stays bounded: keep at most
    // vector_cache samples (2 s at the model rate) before the earliest data
    // still needed, advancing `offset` by the amount erased so absolute
    // positions remain valid.
    int vector_cache = MODEL_SAMPLE_RATE*2;
    if(speech_offline_start == -1){
        if(all_samples.size() > vector_cache){
            int erase_num = all_samples.size() - vector_cache;
            all_samples.erase(all_samples.begin(), all_samples.begin()+erase_num);
            offset += erase_num;
        }
    }else{
        int offline_start = speech_offline_start*seg_sample;
        if(offline_start-offset > vector_cache){
            int erase_num = offline_start-offset - vector_cache;
            all_samples.erase(all_samples.begin(), all_samples.begin()+erase_num);
            offset += erase_num;
        }
    }

}
|
||||
|
||||
} // namespace funasr
|
||||
@ -5,7 +5,8 @@ namespace funasr {
|
||||
typedef struct
|
||||
{
|
||||
std::string msg;
|
||||
float snippet_time;
|
||||
std::string tpass_msg;
|
||||
float snippet_time;
|
||||
}FUNASR_RECOG_RESULT;
|
||||
|
||||
typedef struct
|
||||
|
||||
@ -181,11 +181,12 @@ vector<int> CTTransformerOnline::Infer(vector<int32_t> input_data, int nCacheSiz
|
||||
text_lengths_dim.size()); //, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32);
|
||||
|
||||
//vad_mask
|
||||
vector<float> arVadMask,arSubMask;
|
||||
// vector<float> arVadMask,arSubMask;
|
||||
vector<float> arVadMask;
|
||||
int nTextLength = input_data.size();
|
||||
|
||||
VadMask(nTextLength, nCacheSize, arVadMask);
|
||||
Triangle(nTextLength, arSubMask);
|
||||
// Triangle(nTextLength, arSubMask);
|
||||
std::array<int64_t, 4> VadMask_Dim{ 1,1, nTextLength ,nTextLength };
|
||||
Ort::Value onnx_vad_mask = Ort::Value::CreateTensor<float>(
|
||||
m_memoryInfo,
|
||||
@ -198,8 +199,8 @@ vector<int> CTTransformerOnline::Infer(vector<int32_t> input_data, int nCacheSiz
|
||||
std::array<int64_t, 4> SubMask_Dim{ 1,1, nTextLength ,nTextLength };
|
||||
Ort::Value onnx_sub_mask = Ort::Value::CreateTensor<float>(
|
||||
m_memoryInfo,
|
||||
arSubMask.data(),
|
||||
arSubMask.size() ,
|
||||
arVadMask.data(),
|
||||
arVadMask.size(),
|
||||
SubMask_Dim.data(),
|
||||
SubMask_Dim.size()); // , ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);
|
||||
|
||||
|
||||
@ -55,7 +55,7 @@ void FsmnVadOnline::ExtractFeats(float sample_rate, vector<std::vector<float>> &
|
||||
int frame_from_waves = (waves.size() - frame_sample_length_) / frame_shift_sample_length_ + 1;
|
||||
int minus_frame = reserve_waveforms_.empty() ? (lfr_m - 1) / 2 : 0;
|
||||
int lfr_splice_frame_idxs = OnlineLfrCmvn(vad_feats, input_finished);
|
||||
int reserve_frame_idx = lfr_splice_frame_idxs - minus_frame;
|
||||
int reserve_frame_idx = std::abs(lfr_splice_frame_idxs - minus_frame);
|
||||
reserve_waveforms_.clear();
|
||||
reserve_waveforms_.insert(reserve_waveforms_.begin(),
|
||||
waves.begin() + reserve_frame_idx * frame_shift_sample_length_,
|
||||
@ -86,7 +86,7 @@ void FsmnVadOnline::ExtractFeats(float sample_rate, vector<std::vector<float>> &
|
||||
int FsmnVadOnline::OnlineLfrCmvn(vector<vector<float>> &vad_feats, bool input_finished) {
|
||||
vector<vector<float>> out_feats;
|
||||
int T = vad_feats.size();
|
||||
int T_lrf = ceil((T - (lfr_m - 1) / 2) / lfr_n);
|
||||
int T_lrf = ceil((T - (lfr_m - 1) / 2) / (float)lfr_n);
|
||||
int lfr_splice_frame_idxs = T_lrf;
|
||||
vector<float> p;
|
||||
for (int i = 0; i < T_lrf; i++) {
|
||||
@ -175,6 +175,9 @@ void FsmnVadOnline::InitOnline(std::shared_ptr<Ort::Session> &vad_session,
|
||||
vad_silence_duration_ = vad_silence_duration;
|
||||
vad_max_len_ = vad_max_len;
|
||||
vad_speech_noise_thres_ = vad_speech_noise_thres;
|
||||
|
||||
// 2pass
|
||||
audio_handle = make_unique<Audio>(1);
|
||||
}
|
||||
|
||||
FsmnVadOnline::~FsmnVadOnline() {
|
||||
|
||||
@ -21,6 +21,8 @@ public:
|
||||
std::vector<std::vector<int>> Infer(std::vector<float> &waves, bool input_finished);
|
||||
void ExtractFeats(float sample_rate, vector<vector<float>> &vad_feats, vector<float> &waves, bool input_finished);
|
||||
void Reset();
|
||||
// 2pass
|
||||
std::unique_ptr<Audio> audio_handle = nullptr;
|
||||
|
||||
private:
|
||||
E2EVadModel vad_scorer = E2EVadModel();
|
||||
|
||||
@ -5,9 +5,15 @@ extern "C" {
|
||||
#endif
|
||||
|
||||
// APIs for Init
|
||||
_FUNASRAPI FUNASR_HANDLE FunASRInit(std::map<std::string, std::string>& model_path, int thread_num)
|
||||
_FUNASRAPI FUNASR_HANDLE FunASRInit(std::map<std::string, std::string>& model_path, int thread_num, ASR_TYPE type)
|
||||
{
|
||||
funasr::Model* mm = funasr::CreateModel(model_path, thread_num);
|
||||
funasr::Model* mm = funasr::CreateModel(model_path, thread_num, type);
|
||||
return mm;
|
||||
}
|
||||
|
||||
_FUNASRAPI FUNASR_HANDLE FunASROnlineInit(FUNASR_HANDLE asr_hanlde, std::vector<int> chunk_size)
|
||||
{
|
||||
funasr::Model* mm = funasr::CreateModel(asr_hanlde, chunk_size);
|
||||
return mm;
|
||||
}
|
||||
|
||||
@ -35,8 +41,19 @@ extern "C" {
|
||||
return mm;
|
||||
}
|
||||
|
||||
_FUNASRAPI FUNASR_HANDLE FunTpassInit(std::map<std::string, std::string>& model_path, int thread_num)
|
||||
{
|
||||
funasr::TpassStream* mm = funasr::CreateTpassStream(model_path, thread_num);
|
||||
return mm;
|
||||
}
|
||||
|
||||
_FUNASRAPI FUNASR_HANDLE FunTpassOnlineInit(FUNASR_HANDLE tpass_handle, std::vector<int> chunk_size)
|
||||
{
|
||||
return funasr::CreateTpassOnlineStream(tpass_handle, chunk_size);
|
||||
}
|
||||
|
||||
// APIs for ASR Infer
|
||||
_FUNASRAPI FUNASR_RESULT FunASRInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback, int sampling_rate, std::string wav_format)
|
||||
_FUNASRAPI FUNASR_RESULT FunASRInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback, bool input_finished, int sampling_rate, std::string wav_format)
|
||||
{
|
||||
funasr::Model* recog_obj = (funasr::Model*)handle;
|
||||
if (!recog_obj)
|
||||
@ -57,12 +74,12 @@ extern "C" {
|
||||
funasr::FUNASR_RECOG_RESULT* p_result = new funasr::FUNASR_RECOG_RESULT;
|
||||
p_result->snippet_time = audio.GetTimeLen();
|
||||
if(p_result->snippet_time == 0){
|
||||
return p_result;
|
||||
}
|
||||
return p_result;
|
||||
}
|
||||
int n_step = 0;
|
||||
int n_total = audio.GetQueueSize();
|
||||
while (audio.Fetch(buff, len, flag) > 0) {
|
||||
string msg = recog_obj->Forward(buff, len, flag);
|
||||
string msg = recog_obj->Forward(buff, len, input_finished);
|
||||
p_result->msg += msg;
|
||||
n_step++;
|
||||
if (fn_callback)
|
||||
@ -102,7 +119,7 @@ extern "C" {
|
||||
return p_result;
|
||||
}
|
||||
while (audio.Fetch(buff, len, flag) > 0) {
|
||||
string msg = recog_obj->Forward(buff, len, flag);
|
||||
string msg = recog_obj->Forward(buff, len, true);
|
||||
p_result->msg += msg;
|
||||
n_step++;
|
||||
if (fn_callback)
|
||||
@ -230,7 +247,7 @@ extern "C" {
|
||||
int n_step = 0;
|
||||
int n_total = audio.GetQueueSize();
|
||||
while (audio.Fetch(buff, len, flag) > 0) {
|
||||
string msg = (offline_stream->asr_handle)->Forward(buff, len, flag);
|
||||
string msg = (offline_stream->asr_handle)->Forward(buff, len, true);
|
||||
p_result->msg += msg;
|
||||
n_step++;
|
||||
if (fn_callback)
|
||||
@ -277,7 +294,7 @@ extern "C" {
|
||||
int n_step = 0;
|
||||
int n_total = audio.GetQueueSize();
|
||||
while (audio.Fetch(buff, len, flag) > 0) {
|
||||
string msg = (offline_stream->asr_handle)->Forward(buff, len, flag);
|
||||
string msg = (offline_stream->asr_handle)->Forward(buff, len, true);
|
||||
p_result->msg+= msg;
|
||||
n_step++;
|
||||
if (fn_callback)
|
||||
@ -291,6 +308,91 @@ extern "C" {
|
||||
return p_result;
|
||||
}
|
||||
|
||||
// APIs for 2pass-stream Infer
|
||||
_FUNASRAPI FUNASR_RESULT FunTpassInferBuffer(FUNASR_HANDLE handle, FUNASR_HANDLE online_handle, const char* sz_buf, int n_len, std::vector<std::vector<std::string>> &punc_cache, bool input_finished, int sampling_rate, std::string wav_format, ASR_TYPE mode)
|
||||
{
|
||||
funasr::TpassStream* tpass_stream = (funasr::TpassStream*)handle;
|
||||
funasr::TpassOnlineStream* tpass_online_stream = (funasr::TpassOnlineStream*)online_handle;
|
||||
if (!tpass_stream || !tpass_online_stream)
|
||||
return nullptr;
|
||||
|
||||
funasr::VadModel* vad_online_handle = (tpass_online_stream->vad_online_handle).get();
|
||||
if (!vad_online_handle)
|
||||
return nullptr;
|
||||
|
||||
funasr::Audio* audio = ((funasr::FsmnVadOnline*)vad_online_handle)->audio_handle.get();
|
||||
|
||||
funasr::Model* asr_online_handle = (tpass_online_stream->asr_online_handle).get();
|
||||
if (!asr_online_handle)
|
||||
return nullptr;
|
||||
int chunk_len = ((funasr::ParaformerOnline*)asr_online_handle)->chunk_len;
|
||||
|
||||
funasr::Model* asr_handle = (tpass_stream->asr_handle).get();
|
||||
if (!asr_handle)
|
||||
return nullptr;
|
||||
|
||||
funasr::PuncModel* punc_online_handle = (tpass_stream->punc_online_handle).get();
|
||||
if (!punc_online_handle)
|
||||
return nullptr;
|
||||
|
||||
if(wav_format == "pcm" || wav_format == "PCM"){
|
||||
if (!audio->LoadPcmwavOnline(sz_buf, n_len, &sampling_rate))
|
||||
return nullptr;
|
||||
}else{
|
||||
// if (!audio->FfmpegLoad(sz_buf, n_len))
|
||||
// return nullptr;
|
||||
LOG(ERROR) <<"Wrong wav_format: " << wav_format ;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
funasr::FUNASR_RECOG_RESULT* p_result = new funasr::FUNASR_RECOG_RESULT;
|
||||
p_result->snippet_time = audio->GetTimeLen();
|
||||
if(p_result->snippet_time == 0){
|
||||
return p_result;
|
||||
}
|
||||
|
||||
audio->Split(vad_online_handle, chunk_len, input_finished, mode);
|
||||
|
||||
funasr::AudioFrame* frame = NULL;
|
||||
while(audio->FetchChunck(frame) > 0){
|
||||
string msg = asr_online_handle->Forward(frame->data, frame->len, frame->is_final);
|
||||
if(mode == ASR_ONLINE){
|
||||
((funasr::ParaformerOnline*)asr_online_handle)->online_res += msg;
|
||||
if(frame->is_final){
|
||||
string online_msg = ((funasr::ParaformerOnline*)asr_online_handle)->online_res;
|
||||
string msg_punc = punc_online_handle->AddPunc(online_msg.c_str(), punc_cache[0]);
|
||||
p_result->tpass_msg = msg_punc;
|
||||
((funasr::ParaformerOnline*)asr_online_handle)->online_res = "";
|
||||
p_result->msg += msg;
|
||||
}else{
|
||||
p_result->msg += msg;
|
||||
}
|
||||
}else if(mode == ASR_TWO_PASS){
|
||||
p_result->msg += msg;
|
||||
}
|
||||
if(frame != NULL){
|
||||
delete frame;
|
||||
frame = NULL;
|
||||
}
|
||||
}
|
||||
|
||||
while(audio->FetchTpass(frame) > 0){
|
||||
string msg = asr_handle->Forward(frame->data, frame->len, frame->is_final);
|
||||
string msg_punc = punc_online_handle->AddPunc(msg.c_str(), punc_cache[1]);
|
||||
p_result->tpass_msg = msg_punc;
|
||||
if(frame != NULL){
|
||||
delete frame;
|
||||
frame = NULL;
|
||||
}
|
||||
}
|
||||
|
||||
if(input_finished){
|
||||
audio->ResetIndex();
|
||||
}
|
||||
|
||||
return p_result;
|
||||
}
|
||||
|
||||
_FUNASRAPI const int FunASRGetRetNumber(FUNASR_RESULT result)
|
||||
{
|
||||
if (!result)
|
||||
@ -326,6 +428,15 @@ extern "C" {
|
||||
return p_result->msg.c_str();
|
||||
}
|
||||
|
||||
_FUNASRAPI const char* FunASRGetTpassResult(FUNASR_RESULT result,int n_index)
|
||||
{
|
||||
funasr::FUNASR_RECOG_RESULT * p_result = (funasr::FUNASR_RECOG_RESULT*)result;
|
||||
if(!p_result)
|
||||
return nullptr;
|
||||
|
||||
return p_result->tpass_msg.c_str();
|
||||
}
|
||||
|
||||
_FUNASRAPI const char* CTTransformerGetResult(FUNASR_RESULT result,int n_index)
|
||||
{
|
||||
funasr::FUNASR_PUNC_RESULT * p_result = (funasr::FUNASR_PUNC_RESULT*)result;
|
||||
@ -414,6 +525,26 @@ extern "C" {
|
||||
delete offline_stream;
|
||||
}
|
||||
|
||||
_FUNASRAPI void FunTpassUninit(FUNASR_HANDLE handle)
|
||||
{
|
||||
funasr::TpassStream* tpass_stream = (funasr::TpassStream*)handle;
|
||||
|
||||
if (!tpass_stream)
|
||||
return;
|
||||
|
||||
delete tpass_stream;
|
||||
}
|
||||
|
||||
_FUNASRAPI void FunTpassOnlineUninit(FUNASR_HANDLE handle)
|
||||
{
|
||||
funasr::TpassOnlineStream* tpass_online_stream = (funasr::TpassOnlineStream*)handle;
|
||||
|
||||
if (!tpass_online_stream)
|
||||
return;
|
||||
|
||||
delete tpass_online_stream;
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
||||
}
|
||||
|
||||
@ -1,22 +1,55 @@
|
||||
#include "precomp.h"
|
||||
|
||||
namespace funasr {
|
||||
Model *CreateModel(std::map<std::string, std::string>& model_path, int thread_num)
|
||||
Model *CreateModel(std::map<std::string, std::string>& model_path, int thread_num, ASR_TYPE type)
|
||||
{
|
||||
string am_model_path;
|
||||
string am_cmvn_path;
|
||||
string am_config_path;
|
||||
// offline
|
||||
if(type == ASR_OFFLINE){
|
||||
string am_model_path;
|
||||
string am_cmvn_path;
|
||||
string am_config_path;
|
||||
|
||||
am_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_NAME);
|
||||
if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){
|
||||
am_model_path = PathAppend(model_path.at(MODEL_DIR), QUANT_MODEL_NAME);
|
||||
am_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_NAME);
|
||||
if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){
|
||||
am_model_path = PathAppend(model_path.at(MODEL_DIR), QUANT_MODEL_NAME);
|
||||
}
|
||||
am_cmvn_path = PathAppend(model_path.at(MODEL_DIR), AM_CMVN_NAME);
|
||||
am_config_path = PathAppend(model_path.at(MODEL_DIR), AM_CONFIG_NAME);
|
||||
|
||||
Model *mm;
|
||||
mm = new Paraformer();
|
||||
mm->InitAsr(am_model_path, am_cmvn_path, am_config_path, thread_num);
|
||||
return mm;
|
||||
}else if(type == ASR_ONLINE){
|
||||
// online
|
||||
string en_model_path;
|
||||
string de_model_path;
|
||||
string am_cmvn_path;
|
||||
string am_config_path;
|
||||
|
||||
en_model_path = PathAppend(model_path.at(MODEL_DIR), ENCODER_NAME);
|
||||
de_model_path = PathAppend(model_path.at(MODEL_DIR), DECODER_NAME);
|
||||
if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){
|
||||
en_model_path = PathAppend(model_path.at(MODEL_DIR), QUANT_ENCODER_NAME);
|
||||
de_model_path = PathAppend(model_path.at(MODEL_DIR), QUANT_DECODER_NAME);
|
||||
}
|
||||
am_cmvn_path = PathAppend(model_path.at(MODEL_DIR), AM_CMVN_NAME);
|
||||
am_config_path = PathAppend(model_path.at(MODEL_DIR), AM_CONFIG_NAME);
|
||||
|
||||
Model *mm;
|
||||
mm = new Paraformer();
|
||||
mm->InitAsr(en_model_path, de_model_path, am_cmvn_path, am_config_path, thread_num);
|
||||
return mm;
|
||||
}else{
|
||||
LOG(ERROR)<<"Wrong ASR_TYPE : " << type;
|
||||
exit(-1);
|
||||
}
|
||||
am_cmvn_path = PathAppend(model_path.at(MODEL_DIR), AM_CMVN_NAME);
|
||||
am_config_path = PathAppend(model_path.at(MODEL_DIR), AM_CONFIG_NAME);
|
||||
}
|
||||
|
||||
Model *mm;
|
||||
mm = new Paraformer();
|
||||
mm->InitAsr(am_model_path, am_cmvn_path, am_config_path, thread_num);
|
||||
Model *CreateModel(void* asr_handle, std::vector<int> chunk_size)
|
||||
{
|
||||
Model* mm;
|
||||
mm = new ParaformerOnline((Paraformer*)asr_handle, chunk_size);
|
||||
return mm;
|
||||
}
|
||||
|
||||
|
||||
551
funasr/runtime/onnxruntime/src/paraformer-online.cpp
Normal file
551
funasr/runtime/onnxruntime/src/paraformer-online.cpp
Normal file
@ -0,0 +1,551 @@
|
||||
/**
|
||||
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
|
||||
* MIT License (https://opensource.org/licenses/MIT)
|
||||
*/
|
||||
|
||||
#include "precomp.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace funasr {
|
||||
|
||||
ParaformerOnline::ParaformerOnline(Paraformer* para_handle, std::vector<int> chunk_size)
|
||||
:para_handle_(std::move(para_handle)),chunk_size(chunk_size),session_options_{}{
|
||||
InitOnline(
|
||||
para_handle_->fbank_opts_,
|
||||
para_handle_->encoder_session_,
|
||||
para_handle_->decoder_session_,
|
||||
para_handle_->en_szInputNames_,
|
||||
para_handle_->en_szOutputNames_,
|
||||
para_handle_->de_szInputNames_,
|
||||
para_handle_->de_szOutputNames_,
|
||||
para_handle_->means_list_,
|
||||
para_handle_->vars_list_);
|
||||
InitCache();
|
||||
}
|
||||
|
||||
void ParaformerOnline::InitOnline(
|
||||
knf::FbankOptions &fbank_opts,
|
||||
std::shared_ptr<Ort::Session> &encoder_session,
|
||||
std::shared_ptr<Ort::Session> &decoder_session,
|
||||
vector<const char*> &en_szInputNames,
|
||||
vector<const char*> &en_szOutputNames,
|
||||
vector<const char*> &de_szInputNames,
|
||||
vector<const char*> &de_szOutputNames,
|
||||
vector<float> &means_list,
|
||||
vector<float> &vars_list){
|
||||
fbank_opts_ = fbank_opts;
|
||||
encoder_session_ = encoder_session;
|
||||
decoder_session_ = decoder_session;
|
||||
en_szInputNames_ = en_szInputNames;
|
||||
en_szOutputNames_ = en_szOutputNames;
|
||||
de_szInputNames_ = de_szInputNames;
|
||||
de_szOutputNames_ = de_szOutputNames;
|
||||
means_list_ = means_list;
|
||||
vars_list_ = vars_list;
|
||||
|
||||
frame_length = para_handle_->frame_length;
|
||||
frame_shift = para_handle_->frame_shift;
|
||||
n_mels = para_handle_->n_mels;
|
||||
lfr_m = para_handle_->lfr_m;
|
||||
lfr_n = para_handle_->lfr_n;
|
||||
encoder_size = para_handle_->encoder_size;
|
||||
fsmn_layers = para_handle_->fsmn_layers;
|
||||
fsmn_lorder = para_handle_->fsmn_lorder;
|
||||
fsmn_dims = para_handle_->fsmn_dims;
|
||||
cif_threshold = para_handle_->cif_threshold;
|
||||
tail_alphas = para_handle_->tail_alphas;
|
||||
|
||||
// other vars
|
||||
sqrt_factor = std::sqrt(encoder_size);
|
||||
for(int i=0; i<fsmn_lorder*fsmn_dims; i++){
|
||||
fsmn_init_cache_.emplace_back(0);
|
||||
}
|
||||
chunk_len = chunk_size[1]*frame_shift*lfr_n*MODEL_SAMPLE_RATE/1000;
|
||||
}
|
||||
|
||||
void ParaformerOnline::FbankKaldi(float sample_rate, std::vector<std::vector<float>> &wav_feats,
|
||||
std::vector<float> &waves) {
|
||||
knf::OnlineFbank fbank(fbank_opts_);
|
||||
// cache merge
|
||||
waves.insert(waves.begin(), input_cache_.begin(), input_cache_.end());
|
||||
int frame_number = ComputeFrameNum(waves.size(), frame_sample_length_, frame_shift_sample_length_);
|
||||
// Send the audio after the last frame shift position to the cache
|
||||
input_cache_.clear();
|
||||
input_cache_.insert(input_cache_.begin(), waves.begin() + frame_number * frame_shift_sample_length_, waves.end());
|
||||
if (frame_number == 0) {
|
||||
return;
|
||||
}
|
||||
// Delete audio that haven't undergone fbank processing
|
||||
waves.erase(waves.begin() + (frame_number - 1) * frame_shift_sample_length_ + frame_sample_length_, waves.end());
|
||||
|
||||
std::vector<float> buf(waves.size());
|
||||
for (int32_t i = 0; i != waves.size(); ++i) {
|
||||
buf[i] = waves[i] * 32768;
|
||||
}
|
||||
fbank.AcceptWaveform(sample_rate, buf.data(), buf.size());
|
||||
int32_t frames = fbank.NumFramesReady();
|
||||
for (int32_t i = 0; i != frames; ++i) {
|
||||
const float *frame = fbank.GetFrame(i);
|
||||
vector<float> frame_vector(frame, frame + fbank_opts_.mel_opts.num_bins);
|
||||
wav_feats.emplace_back(frame_vector);
|
||||
}
|
||||
}
|
||||
|
||||
void ParaformerOnline::ExtractFeats(float sample_rate, vector<std::vector<float>> &wav_feats,
|
||||
vector<float> &waves, bool input_finished) {
|
||||
FbankKaldi(sample_rate, wav_feats, waves);
|
||||
// cache deal & online lfr,cmvn
|
||||
if (wav_feats.size() > 0) {
|
||||
if (!reserve_waveforms_.empty()) {
|
||||
waves.insert(waves.begin(), reserve_waveforms_.begin(), reserve_waveforms_.end());
|
||||
}
|
||||
if (lfr_splice_cache_.empty()) {
|
||||
for (int i = 0; i < (lfr_m - 1) / 2; i++) {
|
||||
lfr_splice_cache_.emplace_back(wav_feats[0]);
|
||||
}
|
||||
}
|
||||
if (wav_feats.size() + lfr_splice_cache_.size() >= lfr_m) {
|
||||
wav_feats.insert(wav_feats.begin(), lfr_splice_cache_.begin(), lfr_splice_cache_.end());
|
||||
int frame_from_waves = (waves.size() - frame_sample_length_) / frame_shift_sample_length_ + 1;
|
||||
int minus_frame = reserve_waveforms_.empty() ? (lfr_m - 1) / 2 : 0;
|
||||
int lfr_splice_frame_idxs = OnlineLfrCmvn(wav_feats, input_finished);
|
||||
int reserve_frame_idx = std::abs(lfr_splice_frame_idxs - minus_frame);
|
||||
reserve_waveforms_.clear();
|
||||
reserve_waveforms_.insert(reserve_waveforms_.begin(),
|
||||
waves.begin() + reserve_frame_idx * frame_shift_sample_length_,
|
||||
waves.begin() + frame_from_waves * frame_shift_sample_length_);
|
||||
int sample_length = (frame_from_waves - 1) * frame_shift_sample_length_ + frame_sample_length_;
|
||||
waves.erase(waves.begin() + sample_length, waves.end());
|
||||
} else {
|
||||
reserve_waveforms_.clear();
|
||||
reserve_waveforms_.insert(reserve_waveforms_.begin(),
|
||||
waves.begin() + frame_sample_length_ - frame_shift_sample_length_, waves.end());
|
||||
lfr_splice_cache_.insert(lfr_splice_cache_.end(), wav_feats.begin(), wav_feats.end());
|
||||
}
|
||||
} else {
|
||||
if (input_finished) {
|
||||
if (!reserve_waveforms_.empty()) {
|
||||
waves = reserve_waveforms_;
|
||||
}
|
||||
wav_feats = lfr_splice_cache_;
|
||||
OnlineLfrCmvn(wav_feats, input_finished);
|
||||
}
|
||||
}
|
||||
if(input_finished){
|
||||
ResetCache();
|
||||
}
|
||||
}
|
||||
|
||||
int ParaformerOnline::OnlineLfrCmvn(vector<vector<float>> &wav_feats, bool input_finished) {
|
||||
vector<vector<float>> out_feats;
|
||||
int T = wav_feats.size();
|
||||
int T_lrf = ceil((T - (lfr_m - 1) / 2) / (float)lfr_n);
|
||||
int lfr_splice_frame_idxs = T_lrf;
|
||||
vector<float> p;
|
||||
for (int i = 0; i < T_lrf; i++) {
|
||||
if (lfr_m <= T - i * lfr_n) {
|
||||
for (int j = 0; j < lfr_m; j++) {
|
||||
p.insert(p.end(), wav_feats[i * lfr_n + j].begin(), wav_feats[i * lfr_n + j].end());
|
||||
}
|
||||
out_feats.emplace_back(p);
|
||||
p.clear();
|
||||
} else {
|
||||
if (input_finished) {
|
||||
int num_padding = lfr_m - (T - i * lfr_n);
|
||||
for (int j = 0; j < (wav_feats.size() - i * lfr_n); j++) {
|
||||
p.insert(p.end(), wav_feats[i * lfr_n + j].begin(), wav_feats[i * lfr_n + j].end());
|
||||
}
|
||||
for (int j = 0; j < num_padding; j++) {
|
||||
p.insert(p.end(), wav_feats[wav_feats.size() - 1].begin(), wav_feats[wav_feats.size() - 1].end());
|
||||
}
|
||||
out_feats.emplace_back(p);
|
||||
} else {
|
||||
lfr_splice_frame_idxs = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
lfr_splice_frame_idxs = std::min(T - 1, lfr_splice_frame_idxs * lfr_n);
|
||||
lfr_splice_cache_.clear();
|
||||
lfr_splice_cache_.insert(lfr_splice_cache_.begin(), wav_feats.begin() + lfr_splice_frame_idxs, wav_feats.end());
|
||||
|
||||
// Apply cmvn
|
||||
for (auto &out_feat: out_feats) {
|
||||
for (int j = 0; j < means_list_.size(); j++) {
|
||||
out_feat[j] = (out_feat[j] + means_list_[j]) * vars_list_[j];
|
||||
}
|
||||
}
|
||||
wav_feats = out_feats;
|
||||
return lfr_splice_frame_idxs;
|
||||
}
|
||||
|
||||
void ParaformerOnline::GetPosEmb(std::vector<std::vector<float>> &wav_feats, int timesteps, int feat_dim)
|
||||
{
|
||||
int start_idx = start_idx_cache_;
|
||||
start_idx_cache_ += timesteps;
|
||||
int mm = start_idx_cache_;
|
||||
|
||||
int i;
|
||||
float scale = -0.0330119726594128;
|
||||
|
||||
std::vector<float> tmp(mm*feat_dim);
|
||||
|
||||
for (i = 0; i < feat_dim/2; i++) {
|
||||
float tmptime = exp(i * scale);
|
||||
int j;
|
||||
for (j = 0; j < mm; j++) {
|
||||
int sin_idx = j * feat_dim + i;
|
||||
int cos_idx = j * feat_dim + i + feat_dim/2;
|
||||
float coe = tmptime * (j + 1);
|
||||
tmp[sin_idx] = sin(coe);
|
||||
tmp[cos_idx] = cos(coe);
|
||||
}
|
||||
}
|
||||
|
||||
for (i = start_idx; i < start_idx + timesteps; i++) {
|
||||
for (int j = 0; j < feat_dim; j++) {
|
||||
wav_feats[i-start_idx][j] += tmp[i*feat_dim+j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ParaformerOnline::CifSearch(std::vector<std::vector<float>> hidden, std::vector<float> alphas, bool is_final, std::vector<std::vector<float>>& list_frame)
|
||||
{
|
||||
try{
|
||||
int hidden_size = 0;
|
||||
if(hidden.size() > 0){
|
||||
hidden_size = hidden[0].size();
|
||||
}
|
||||
// cache
|
||||
int i,j;
|
||||
int chunk_size_pre = chunk_size[0];
|
||||
for (i = 0; i < chunk_size_pre; i++)
|
||||
alphas[i] = 0.0;
|
||||
|
||||
int chunk_size_suf = std::accumulate(chunk_size.begin(), chunk_size.end()-1, 0);
|
||||
for (i = chunk_size_suf; i < alphas.size(); i++){
|
||||
alphas[i] = 0.0;
|
||||
}
|
||||
|
||||
if(hidden_cache_.size()>0){
|
||||
hidden.insert(hidden.begin(), hidden_cache_.begin(), hidden_cache_.end());
|
||||
alphas.insert(alphas.begin(), alphas_cache_.begin(), alphas_cache_.end());
|
||||
hidden_cache_.clear();
|
||||
alphas_cache_.clear();
|
||||
}
|
||||
|
||||
if (is_last_chunk) {
|
||||
std::vector<float> tail_hidden(hidden_size, 0);
|
||||
hidden.emplace_back(tail_hidden);
|
||||
alphas.emplace_back(tail_alphas);
|
||||
}
|
||||
|
||||
float intergrate = 0.0;
|
||||
int len_time = alphas.size();
|
||||
std::vector<float> frames(hidden_size, 0);
|
||||
std::vector<float> list_fire;
|
||||
|
||||
for (i = 0; i < len_time; i++) {
|
||||
float alpha = alphas[i];
|
||||
if (alpha + intergrate < cif_threshold) {
|
||||
intergrate += alpha;
|
||||
list_fire.emplace_back(intergrate);
|
||||
for (j = 0; j < hidden_size; j++) {
|
||||
frames[j] += alpha * hidden[i][j];
|
||||
}
|
||||
} else {
|
||||
for (j = 0; j < hidden_size; j++) {
|
||||
frames[j] += (cif_threshold - intergrate) * hidden[i][j];
|
||||
}
|
||||
std::vector<float> frames_cp(frames);
|
||||
list_frame.emplace_back(frames_cp);
|
||||
intergrate += alpha;
|
||||
list_fire.emplace_back(intergrate);
|
||||
intergrate -= cif_threshold;
|
||||
for (j = 0; j < hidden_size; j++) {
|
||||
frames[j] = intergrate * hidden[i][j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cache
|
||||
alphas_cache_.emplace_back(intergrate);
|
||||
if (intergrate > 0.0) {
|
||||
std::vector<float> hidden_cache(hidden_size, 0);
|
||||
for (i = 0; i < hidden_size; i++) {
|
||||
hidden_cache[i] = frames[i] / intergrate;
|
||||
}
|
||||
hidden_cache_.emplace_back(hidden_cache);
|
||||
} else {
|
||||
std::vector<float> frames_cp(frames);
|
||||
hidden_cache_.emplace_back(frames_cp);
|
||||
}
|
||||
}catch (std::exception const &e)
|
||||
{
|
||||
LOG(ERROR)<<e.what();
|
||||
}
|
||||
}
|
||||
|
||||
void ParaformerOnline::InitCache(){
|
||||
|
||||
start_idx_cache_ = 0;
|
||||
is_first_chunk = true;
|
||||
is_last_chunk = false;
|
||||
hidden_cache_.clear();
|
||||
alphas_cache_.clear();
|
||||
feats_cache_.clear();
|
||||
decoder_onnx.clear();
|
||||
|
||||
// cif cache
|
||||
std::vector<float> hidden_cache(encoder_size, 0);
|
||||
hidden_cache_.emplace_back(hidden_cache);
|
||||
alphas_cache_.emplace_back(0);
|
||||
|
||||
// feats
|
||||
std::vector<float> feat_cache(feat_dims, 0);
|
||||
for(int i=0; i<(chunk_size[0]+chunk_size[2]); i++){
|
||||
feats_cache_.emplace_back(feat_cache);
|
||||
}
|
||||
|
||||
// fsmn cache
|
||||
#ifdef _WIN_X86
|
||||
Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
|
||||
#else
|
||||
Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
|
||||
#endif
|
||||
const int64_t fsmn_shape_[3] = {1, fsmn_dims, fsmn_lorder};
|
||||
for(int l=0; l<fsmn_layers; l++){
|
||||
Ort::Value onnx_fsmn_cache = Ort::Value::CreateTensor<float>(
|
||||
m_memoryInfo,
|
||||
fsmn_init_cache_.data(),
|
||||
fsmn_init_cache_.size(),
|
||||
fsmn_shape_,
|
||||
3);
|
||||
decoder_onnx.emplace_back(std::move(onnx_fsmn_cache));
|
||||
}
|
||||
};
|
||||
|
||||
void ParaformerOnline::Reset()
|
||||
{
|
||||
InitCache();
|
||||
}
|
||||
|
||||
void ParaformerOnline::ResetCache() {
|
||||
reserve_waveforms_.clear();
|
||||
input_cache_.clear();
|
||||
lfr_splice_cache_.clear();
|
||||
}
|
||||
|
||||
void ParaformerOnline::AddOverlapChunk(std::vector<std::vector<float>> &wav_feats, bool input_finished){
|
||||
wav_feats.insert(wav_feats.begin(), feats_cache_.begin(), feats_cache_.end());
|
||||
if(input_finished){
|
||||
feats_cache_.clear();
|
||||
feats_cache_.insert(feats_cache_.begin(), wav_feats.end()-chunk_size[0], wav_feats.end());
|
||||
if(!is_last_chunk){
|
||||
int padding_length = std::accumulate(chunk_size.begin(), chunk_size.end(), 0) - wav_feats.size();
|
||||
std::vector<float> tmp(feat_dims, 0);
|
||||
for(int i=0; i<padding_length; i++){
|
||||
wav_feats.emplace_back(feat_dims);
|
||||
}
|
||||
}
|
||||
}else{
|
||||
feats_cache_.clear();
|
||||
feats_cache_.insert(feats_cache_.begin(), wav_feats.end()-chunk_size[0]-chunk_size[2], wav_feats.end());
|
||||
}
|
||||
}
|
||||
|
||||
string ParaformerOnline::ForwardChunk(std::vector<std::vector<float>> &chunk_feats, bool input_finished)
|
||||
{
|
||||
string result;
|
||||
try{
|
||||
int32_t num_frames = chunk_feats.size();
|
||||
|
||||
#ifdef _WIN_X86
|
||||
Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
|
||||
#else
|
||||
Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
|
||||
#endif
|
||||
const int64_t input_shape_[3] = {1, num_frames, feat_dims};
|
||||
std::vector<float> wav_feats;
|
||||
for (const auto &chunk_feat: chunk_feats) {
|
||||
wav_feats.insert(wav_feats.end(), chunk_feat.begin(), chunk_feat.end());
|
||||
}
|
||||
Ort::Value onnx_feats = Ort::Value::CreateTensor<float>(
|
||||
m_memoryInfo,
|
||||
wav_feats.data(),
|
||||
wav_feats.size(),
|
||||
input_shape_,
|
||||
3);
|
||||
|
||||
const int64_t paraformer_length_shape[1] = {1};
|
||||
std::vector<int32_t> paraformer_length;
|
||||
paraformer_length.emplace_back(num_frames);
|
||||
Ort::Value onnx_feats_len = Ort::Value::CreateTensor<int32_t>(
|
||||
m_memoryInfo, paraformer_length.data(), paraformer_length.size(), paraformer_length_shape, 1);
|
||||
|
||||
std::vector<Ort::Value> input_onnx;
|
||||
input_onnx.emplace_back(std::move(onnx_feats));
|
||||
input_onnx.emplace_back(std::move(onnx_feats_len));
|
||||
|
||||
auto encoder_tensor = encoder_session_->Run(Ort::RunOptions{nullptr}, en_szInputNames_.data(), input_onnx.data(), input_onnx.size(), en_szOutputNames_.data(), en_szOutputNames_.size());
|
||||
|
||||
// get enc_vec
|
||||
std::vector<int64_t> enc_shape = encoder_tensor[0].GetTensorTypeAndShapeInfo().GetShape();
|
||||
float* enc_data = encoder_tensor[0].GetTensorMutableData<float>();
|
||||
std::vector<std::vector<float>> enc_vec(enc_shape[1], std::vector<float>(enc_shape[2]));
|
||||
for (int i = 0; i < enc_shape[1]; i++) {
|
||||
for (int j = 0; j < enc_shape[2]; j++) {
|
||||
enc_vec[i][j] = enc_data[i * enc_shape[2] + j];
|
||||
}
|
||||
}
|
||||
|
||||
// get alpha_vec
|
||||
std::vector<int64_t> alpha_shape = encoder_tensor[2].GetTensorTypeAndShapeInfo().GetShape();
|
||||
float* alpha_data = encoder_tensor[2].GetTensorMutableData<float>();
|
||||
std::vector<float> alpha_vec(alpha_shape[1]);
|
||||
for (int i = 0; i < alpha_shape[1]; i++) {
|
||||
alpha_vec[i] = alpha_data[i];
|
||||
}
|
||||
|
||||
std::vector<std::vector<float>> list_frame;
|
||||
CifSearch(enc_vec, alpha_vec, input_finished, list_frame);
|
||||
|
||||
|
||||
if(list_frame.size()>0){
|
||||
// enc
|
||||
decoder_onnx.insert(decoder_onnx.begin(), std::move(encoder_tensor[0]));
|
||||
// enc_lens
|
||||
decoder_onnx.insert(decoder_onnx.begin()+1, std::move(encoder_tensor[1]));
|
||||
|
||||
// acoustic_embeds
|
||||
const int64_t emb_shape_[3] = {1, (int64_t)list_frame.size(), (int64_t)list_frame[0].size()};
|
||||
std::vector<float> emb_input;
|
||||
for (const auto &list_frame_: list_frame) {
|
||||
emb_input.insert(emb_input.end(), list_frame_.begin(), list_frame_.end());
|
||||
}
|
||||
Ort::Value onnx_emb = Ort::Value::CreateTensor<float>(
|
||||
m_memoryInfo,
|
||||
emb_input.data(),
|
||||
emb_input.size(),
|
||||
emb_shape_,
|
||||
3);
|
||||
decoder_onnx.insert(decoder_onnx.begin()+2, std::move(onnx_emb));
|
||||
|
||||
// acoustic_embeds_len
|
||||
const int64_t emb_length_shape[1] = {1};
|
||||
std::vector<int32_t> emb_length;
|
||||
emb_length.emplace_back(list_frame.size());
|
||||
Ort::Value onnx_emb_len = Ort::Value::CreateTensor<int32_t>(
|
||||
m_memoryInfo, emb_length.data(), emb_length.size(), emb_length_shape, 1);
|
||||
decoder_onnx.insert(decoder_onnx.begin()+3, std::move(onnx_emb_len));
|
||||
|
||||
auto decoder_tensor = decoder_session_->Run(Ort::RunOptions{nullptr}, de_szInputNames_.data(), decoder_onnx.data(), decoder_onnx.size(), de_szOutputNames_.data(), de_szOutputNames_.size());
|
||||
// fsmn cache
|
||||
try{
|
||||
decoder_onnx.clear();
|
||||
}catch (std::exception const &e)
|
||||
{
|
||||
LOG(ERROR)<<e.what();
|
||||
return result;
|
||||
}
|
||||
for(int l=0;l<fsmn_layers;l++){
|
||||
decoder_onnx.emplace_back(std::move(decoder_tensor[2+l]));
|
||||
}
|
||||
|
||||
std::vector<int64_t> decoder_shape = decoder_tensor[0].GetTensorTypeAndShapeInfo().GetShape();
|
||||
float* float_data = decoder_tensor[0].GetTensorMutableData<float>();
|
||||
result = para_handle_->GreedySearch(float_data, list_frame.size(), decoder_shape[2]);
|
||||
}
|
||||
}catch (std::exception const &e)
|
||||
{
|
||||
LOG(ERROR)<<e.what();
|
||||
return result;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
string ParaformerOnline::Forward(float* din, int len, bool input_finished)
|
||||
{
|
||||
std::vector<std::vector<float>> wav_feats;
|
||||
std::vector<float> waves(din, din+len);
|
||||
|
||||
string result="";
|
||||
try{
|
||||
if(len <16*60 && input_finished && !is_first_chunk){
|
||||
is_last_chunk = true;
|
||||
wav_feats = feats_cache_;
|
||||
result = ForwardChunk(wav_feats, is_last_chunk);
|
||||
// reset
|
||||
ResetCache();
|
||||
Reset();
|
||||
return result;
|
||||
}
|
||||
if(is_first_chunk){
|
||||
is_first_chunk = false;
|
||||
}
|
||||
ExtractFeats(MODEL_SAMPLE_RATE, wav_feats, waves, input_finished);
|
||||
if(wav_feats.size() == 0){
|
||||
return result;
|
||||
}
|
||||
|
||||
for (auto& row : wav_feats) {
|
||||
for (auto& val : row) {
|
||||
val *= sqrt_factor;
|
||||
}
|
||||
}
|
||||
|
||||
GetPosEmb(wav_feats, wav_feats.size(), wav_feats[0].size());
|
||||
if(input_finished){
|
||||
if(wav_feats.size()+chunk_size[2] <= chunk_size[1]){
|
||||
is_last_chunk = true;
|
||||
AddOverlapChunk(wav_feats, input_finished);
|
||||
}else{
|
||||
// first chunk
|
||||
std::vector<std::vector<float>> first_chunk;
|
||||
first_chunk.insert(first_chunk.begin(), wav_feats.begin(), wav_feats.end());
|
||||
AddOverlapChunk(first_chunk, input_finished);
|
||||
string str_first_chunk = ForwardChunk(first_chunk, is_last_chunk);
|
||||
|
||||
// last chunk
|
||||
is_last_chunk = true;
|
||||
std::vector<std::vector<float>> last_chunk;
|
||||
last_chunk.insert(last_chunk.begin(), wav_feats.end()-(wav_feats.size()+chunk_size[2]-chunk_size[1]), wav_feats.end());
|
||||
AddOverlapChunk(last_chunk, input_finished);
|
||||
string str_last_chunk = ForwardChunk(last_chunk, is_last_chunk);
|
||||
|
||||
result = str_first_chunk+str_last_chunk;
|
||||
// reset
|
||||
ResetCache();
|
||||
Reset();
|
||||
return result;
|
||||
}
|
||||
}else{
|
||||
AddOverlapChunk(wav_feats, input_finished);
|
||||
}
|
||||
|
||||
result = ForwardChunk(wav_feats, is_last_chunk);
|
||||
if(input_finished){
|
||||
// reset
|
||||
ResetCache();
|
||||
Reset();
|
||||
}
|
||||
}catch (std::exception const &e)
|
||||
{
|
||||
LOG(ERROR)<<e.what();
|
||||
return result;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
ParaformerOnline::~ParaformerOnline()
|
||||
{
|
||||
}
|
||||
|
||||
string ParaformerOnline::Rescoring()
|
||||
{
|
||||
LOG(ERROR)<<"Not Imp!!!!!!";
|
||||
return "";
|
||||
}
|
||||
} // namespace funasr
|
||||
111
funasr/runtime/onnxruntime/src/paraformer-online.h
Normal file
111
funasr/runtime/onnxruntime/src/paraformer-online.h
Normal file
@ -0,0 +1,111 @@
|
||||
/**
|
||||
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
|
||||
* MIT License (https://opensource.org/licenses/MIT)
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "precomp.h"
|
||||
|
||||
namespace funasr {
|
||||
|
||||
class ParaformerOnline : public Model {
|
||||
/**
|
||||
* Author: Speech Lab of DAMO Academy, Alibaba Group
|
||||
* ParaformerOnline: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
|
||||
* https://arxiv.org/pdf/2206.08317.pdf
|
||||
*/
|
||||
private:
|
||||
|
||||
void FbankKaldi(float sample_rate, std::vector<std::vector<float>> &wav_feats,
|
||||
std::vector<float> &waves);
|
||||
int OnlineLfrCmvn(vector<vector<float>> &wav_feats, bool input_finished);
|
||||
void GetPosEmb(std::vector<std::vector<float>> &wav_feats, int timesteps, int feat_dim);
|
||||
void CifSearch(std::vector<std::vector<float>> hidden, std::vector<float> alphas, bool is_final, std::vector<std::vector<float>> &list_frame);
|
||||
|
||||
static int ComputeFrameNum(int sample_length, int frame_sample_length, int frame_shift_sample_length) {
|
||||
int frame_num = static_cast<int>((sample_length - frame_sample_length) / frame_shift_sample_length + 1);
|
||||
if (frame_num >= 1 && sample_length >= frame_sample_length)
|
||||
return frame_num;
|
||||
else
|
||||
return 0;
|
||||
}
|
||||
void InitOnline(
|
||||
knf::FbankOptions &fbank_opts,
|
||||
std::shared_ptr<Ort::Session> &encoder_session,
|
||||
std::shared_ptr<Ort::Session> &decoder_session,
|
||||
vector<const char*> &en_szInputNames,
|
||||
vector<const char*> &en_szOutputNames,
|
||||
vector<const char*> &de_szInputNames,
|
||||
vector<const char*> &de_szOutputNames,
|
||||
vector<float> &means_list,
|
||||
vector<float> &vars_list);
|
||||
|
||||
Paraformer* para_handle_ = nullptr;
|
||||
// from para_handle_
|
||||
knf::FbankOptions fbank_opts_;
|
||||
std::shared_ptr<Ort::Session> encoder_session_ = nullptr;
|
||||
std::shared_ptr<Ort::Session> decoder_session_ = nullptr;
|
||||
Ort::SessionOptions session_options_;
|
||||
vector<const char*> en_szInputNames_;
|
||||
vector<const char*> en_szOutputNames_;
|
||||
vector<const char*> de_szInputNames_;
|
||||
vector<const char*> de_szOutputNames_;
|
||||
vector<float> means_list_;
|
||||
vector<float> vars_list_;
|
||||
// configs from para_handle_
|
||||
int frame_length = 25;
|
||||
int frame_shift = 10;
|
||||
int n_mels = 80;
|
||||
int lfr_m = PARA_LFR_M;
|
||||
int lfr_n = PARA_LFR_N;
|
||||
int encoder_size = 512;
|
||||
int fsmn_layers = 16;
|
||||
int fsmn_lorder = 10;
|
||||
int fsmn_dims = 512;
|
||||
float cif_threshold = 1.0;
|
||||
float tail_alphas = 0.45;
|
||||
|
||||
// configs
|
||||
int feat_dims = lfr_m*n_mels;
|
||||
std::vector<int> chunk_size = {5,10,5};
|
||||
int frame_sample_length_ = MODEL_SAMPLE_RATE / 1000 * frame_length;
|
||||
int frame_shift_sample_length_ = MODEL_SAMPLE_RATE / 1000 * frame_shift;
|
||||
|
||||
// The reserved waveforms by fbank
|
||||
std::vector<float> reserve_waveforms_;
|
||||
// waveforms reserved after last shift position
|
||||
std::vector<float> input_cache_;
|
||||
// lfr reserved cache
|
||||
std::vector<std::vector<float>> lfr_splice_cache_;
|
||||
// position index cache
|
||||
int start_idx_cache_ = 0;
|
||||
// cif alpha
|
||||
std::vector<float> alphas_cache_;
|
||||
std::vector<std::vector<float>> hidden_cache_;
|
||||
std::vector<std::vector<float>> feats_cache_;
|
||||
// fsmn init caches
|
||||
std::vector<float> fsmn_init_cache_;
|
||||
std::vector<Ort::Value> decoder_onnx;
|
||||
|
||||
bool is_first_chunk = true;
|
||||
bool is_last_chunk = false;
|
||||
double sqrt_factor;
|
||||
|
||||
public:
|
||||
ParaformerOnline(Paraformer* para_handle, std::vector<int> chunk_size);
|
||||
~ParaformerOnline();
|
||||
void Reset();
|
||||
void ResetCache();
|
||||
void InitCache();
|
||||
void ExtractFeats(float sample_rate, vector<vector<float>> &wav_feats, vector<float> &waves, bool input_finished);
|
||||
void AddOverlapChunk(std::vector<std::vector<float>> &wav_feats, bool input_finished);
|
||||
|
||||
string ForwardChunk(std::vector<std::vector<float>> &wav_feats, bool input_finished);
|
||||
string Forward(float* din, int len, bool input_finished);
|
||||
string Rescoring();
|
||||
// 2pass
|
||||
std::string online_res;
|
||||
int chunk_len;
|
||||
};
|
||||
|
||||
} // namespace funasr
|
||||
@ -10,29 +10,30 @@ using namespace std;
|
||||
namespace funasr {
|
||||
|
||||
Paraformer::Paraformer()
|
||||
:env_(ORT_LOGGING_LEVEL_ERROR, "paraformer"),session_options{}{
|
||||
:env_(ORT_LOGGING_LEVEL_ERROR, "paraformer"),session_options_{}{
|
||||
}
|
||||
|
||||
// offline
|
||||
void Paraformer::InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){
|
||||
// knf options
|
||||
fbank_opts.frame_opts.dither = 0;
|
||||
fbank_opts.mel_opts.num_bins = 80;
|
||||
fbank_opts.frame_opts.samp_freq = MODEL_SAMPLE_RATE;
|
||||
fbank_opts.frame_opts.window_type = "hamming";
|
||||
fbank_opts.frame_opts.frame_shift_ms = 10;
|
||||
fbank_opts.frame_opts.frame_length_ms = 25;
|
||||
fbank_opts.energy_floor = 0;
|
||||
fbank_opts.mel_opts.debug_mel = false;
|
||||
fbank_opts_.frame_opts.dither = 0;
|
||||
fbank_opts_.mel_opts.num_bins = n_mels;
|
||||
fbank_opts_.frame_opts.samp_freq = MODEL_SAMPLE_RATE;
|
||||
fbank_opts_.frame_opts.window_type = window_type;
|
||||
fbank_opts_.frame_opts.frame_shift_ms = frame_shift;
|
||||
fbank_opts_.frame_opts.frame_length_ms = frame_length;
|
||||
fbank_opts_.energy_floor = 0;
|
||||
fbank_opts_.mel_opts.debug_mel = false;
|
||||
// fbank_ = std::make_unique<knf::OnlineFbank>(fbank_opts);
|
||||
|
||||
// session_options.SetInterOpNumThreads(1);
|
||||
session_options.SetIntraOpNumThreads(thread_num);
|
||||
session_options.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
|
||||
// session_options_.SetInterOpNumThreads(1);
|
||||
session_options_.SetIntraOpNumThreads(thread_num);
|
||||
session_options_.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
|
||||
// DisableCpuMemArena can improve performance
|
||||
session_options.DisableCpuMemArena();
|
||||
session_options_.DisableCpuMemArena();
|
||||
|
||||
try {
|
||||
m_session = std::make_unique<Ort::Session>(env_, am_model.c_str(), session_options);
|
||||
m_session_ = std::make_unique<Ort::Session>(env_, am_model.c_str(), session_options_);
|
||||
LOG(INFO) << "Successfully load model from " << am_model;
|
||||
} catch (std::exception const &e) {
|
||||
LOG(ERROR) << "Error when load am onnx model: " << e.what();
|
||||
@ -40,14 +41,14 @@ void Paraformer::InitAsr(const std::string &am_model, const std::string &am_cmvn
|
||||
}
|
||||
|
||||
string strName;
|
||||
GetInputName(m_session.get(), strName);
|
||||
GetInputName(m_session_.get(), strName);
|
||||
m_strInputNames.push_back(strName.c_str());
|
||||
GetInputName(m_session.get(), strName,1);
|
||||
GetInputName(m_session_.get(), strName,1);
|
||||
m_strInputNames.push_back(strName);
|
||||
|
||||
GetOutputName(m_session.get(), strName);
|
||||
GetOutputName(m_session_.get(), strName);
|
||||
m_strOutputNames.push_back(strName);
|
||||
GetOutputName(m_session.get(), strName,1);
|
||||
GetOutputName(m_session_.get(), strName,1);
|
||||
m_strOutputNames.push_back(strName);
|
||||
|
||||
for (auto& item : m_strInputNames)
|
||||
@ -58,6 +59,152 @@ void Paraformer::InitAsr(const std::string &am_model, const std::string &am_cmvn
|
||||
LoadCmvn(am_cmvn.c_str());
|
||||
}
|
||||
|
||||
// online
|
||||
void Paraformer::InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){
|
||||
|
||||
LoadOnlineConfigFromYaml(am_config.c_str());
|
||||
// knf options
|
||||
fbank_opts_.frame_opts.dither = 0;
|
||||
fbank_opts_.mel_opts.num_bins = n_mels;
|
||||
fbank_opts_.frame_opts.samp_freq = MODEL_SAMPLE_RATE;
|
||||
fbank_opts_.frame_opts.window_type = window_type;
|
||||
fbank_opts_.frame_opts.frame_shift_ms = frame_shift;
|
||||
fbank_opts_.frame_opts.frame_length_ms = frame_length;
|
||||
fbank_opts_.energy_floor = 0;
|
||||
fbank_opts_.mel_opts.debug_mel = false;
|
||||
|
||||
// session_options_.SetInterOpNumThreads(1);
|
||||
session_options_.SetIntraOpNumThreads(thread_num);
|
||||
session_options_.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
|
||||
// DisableCpuMemArena can improve performance
|
||||
session_options_.DisableCpuMemArena();
|
||||
|
||||
try {
|
||||
encoder_session_ = std::make_unique<Ort::Session>(env_, en_model.c_str(), session_options_);
|
||||
LOG(INFO) << "Successfully load model from " << en_model;
|
||||
} catch (std::exception const &e) {
|
||||
LOG(ERROR) << "Error when load am encoder model: " << e.what();
|
||||
exit(0);
|
||||
}
|
||||
|
||||
try {
|
||||
decoder_session_ = std::make_unique<Ort::Session>(env_, de_model.c_str(), session_options_);
|
||||
LOG(INFO) << "Successfully load model from " << de_model;
|
||||
} catch (std::exception const &e) {
|
||||
LOG(ERROR) << "Error when load am decoder model: " << e.what();
|
||||
exit(0);
|
||||
}
|
||||
|
||||
// encoder
|
||||
string strName;
|
||||
GetInputName(encoder_session_.get(), strName);
|
||||
en_strInputNames.push_back(strName.c_str());
|
||||
GetInputName(encoder_session_.get(), strName,1);
|
||||
en_strInputNames.push_back(strName);
|
||||
|
||||
GetOutputName(encoder_session_.get(), strName);
|
||||
en_strOutputNames.push_back(strName);
|
||||
GetOutputName(encoder_session_.get(), strName,1);
|
||||
en_strOutputNames.push_back(strName);
|
||||
GetOutputName(encoder_session_.get(), strName,2);
|
||||
en_strOutputNames.push_back(strName);
|
||||
|
||||
for (auto& item : en_strInputNames)
|
||||
en_szInputNames_.push_back(item.c_str());
|
||||
for (auto& item : en_strOutputNames)
|
||||
en_szOutputNames_.push_back(item.c_str());
|
||||
|
||||
// decoder
|
||||
int de_input_len = 4 + fsmn_layers;
|
||||
int de_out_len = 2 + fsmn_layers;
|
||||
for(int i=0;i<de_input_len; i++){
|
||||
GetInputName(decoder_session_.get(), strName, i);
|
||||
de_strInputNames.push_back(strName.c_str());
|
||||
}
|
||||
|
||||
for(int i=0;i<de_out_len; i++){
|
||||
GetOutputName(decoder_session_.get(), strName,i);
|
||||
de_strOutputNames.push_back(strName);
|
||||
}
|
||||
|
||||
for (auto& item : de_strInputNames)
|
||||
de_szInputNames_.push_back(item.c_str());
|
||||
for (auto& item : de_strOutputNames)
|
||||
de_szOutputNames_.push_back(item.c_str());
|
||||
|
||||
vocab = new Vocab(am_config.c_str());
|
||||
LoadCmvn(am_cmvn.c_str());
|
||||
}
|
||||
|
||||
// 2pass
|
||||
void Paraformer::InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){
|
||||
// online
|
||||
InitAsr(en_model, de_model, am_cmvn, am_config, thread_num);
|
||||
|
||||
// offline
|
||||
try {
|
||||
m_session_ = std::make_unique<Ort::Session>(env_, am_model.c_str(), session_options_);
|
||||
LOG(INFO) << "Successfully load model from " << am_model;
|
||||
} catch (std::exception const &e) {
|
||||
LOG(ERROR) << "Error when load am onnx model: " << e.what();
|
||||
exit(0);
|
||||
}
|
||||
|
||||
string strName;
|
||||
GetInputName(m_session_.get(), strName);
|
||||
m_strInputNames.push_back(strName.c_str());
|
||||
GetInputName(m_session_.get(), strName,1);
|
||||
m_strInputNames.push_back(strName);
|
||||
|
||||
GetOutputName(m_session_.get(), strName);
|
||||
m_strOutputNames.push_back(strName);
|
||||
GetOutputName(m_session_.get(), strName,1);
|
||||
m_strOutputNames.push_back(strName);
|
||||
|
||||
for (auto& item : m_strInputNames)
|
||||
m_szInputNames.push_back(item.c_str());
|
||||
for (auto& item : m_strOutputNames)
|
||||
m_szOutputNames.push_back(item.c_str());
|
||||
}
|
||||
|
||||
void Paraformer::LoadOnlineConfigFromYaml(const char* filename){
|
||||
|
||||
YAML::Node config;
|
||||
try{
|
||||
config = YAML::LoadFile(filename);
|
||||
}catch(exception const &e){
|
||||
LOG(ERROR) << "Error loading file, yaml file error or not exist.";
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
try{
|
||||
YAML::Node frontend_conf = config["frontend_conf"];
|
||||
YAML::Node encoder_conf = config["encoder_conf"];
|
||||
YAML::Node decoder_conf = config["decoder_conf"];
|
||||
YAML::Node predictor_conf = config["predictor_conf"];
|
||||
|
||||
this->window_type = frontend_conf["window"].as<string>();
|
||||
this->n_mels = frontend_conf["n_mels"].as<int>();
|
||||
this->frame_length = frontend_conf["frame_length"].as<int>();
|
||||
this->frame_shift = frontend_conf["frame_shift"].as<int>();
|
||||
this->lfr_m = frontend_conf["lfr_m"].as<int>();
|
||||
this->lfr_n = frontend_conf["lfr_n"].as<int>();
|
||||
|
||||
this->encoder_size = encoder_conf["output_size"].as<int>();
|
||||
this->fsmn_dims = encoder_conf["output_size"].as<int>();
|
||||
|
||||
this->fsmn_layers = decoder_conf["num_blocks"].as<int>();
|
||||
this->fsmn_lorder = decoder_conf["kernel_size"].as<int>()-1;
|
||||
|
||||
this->cif_threshold = predictor_conf["threshold"].as<double>();
|
||||
this->tail_alphas = predictor_conf["tail_threshold"].as<double>();
|
||||
|
||||
}catch(exception const &e){
|
||||
LOG(ERROR) << "Error when load argument from vad config YAML.";
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
|
||||
Paraformer::~Paraformer()
|
||||
{
|
||||
if(vocab)
|
||||
@ -69,7 +216,7 @@ void Paraformer::Reset()
|
||||
}
|
||||
|
||||
vector<float> Paraformer::FbankKaldi(float sample_rate, const float* waves, int len) {
|
||||
knf::OnlineFbank fbank_(fbank_opts);
|
||||
knf::OnlineFbank fbank_(fbank_opts_);
|
||||
std::vector<float> buf(len);
|
||||
for (int32_t i = 0; i != len; ++i) {
|
||||
buf[i] = waves[i] * 32768;
|
||||
@ -77,7 +224,7 @@ vector<float> Paraformer::FbankKaldi(float sample_rate, const float* waves, int
|
||||
fbank_.AcceptWaveform(sample_rate, buf.data(), buf.size());
|
||||
//fbank_->InputFinished();
|
||||
int32_t frames = fbank_.NumFramesReady();
|
||||
int32_t feature_dim = fbank_opts.mel_opts.num_bins;
|
||||
int32_t feature_dim = fbank_opts_.mel_opts.num_bins;
|
||||
vector<float> features(frames * feature_dim);
|
||||
float *p = features.data();
|
||||
|
||||
@ -108,7 +255,7 @@ void Paraformer::LoadCmvn(const char *filename)
|
||||
vector<string> means_lines{istream_iterator<string>{means_lines_stream}, istream_iterator<string>{}};
|
||||
if (means_lines[0] == "<LearnRateCoef>") {
|
||||
for (int j = 3; j < means_lines.size() - 1; j++) {
|
||||
means_list.push_back(stof(means_lines[j]));
|
||||
means_list_.push_back(stof(means_lines[j]));
|
||||
}
|
||||
continue;
|
||||
}
|
||||
@ -119,7 +266,7 @@ void Paraformer::LoadCmvn(const char *filename)
|
||||
vector<string> vars_lines{istream_iterator<string>{vars_lines_stream}, istream_iterator<string>{}};
|
||||
if (vars_lines[0] == "<LearnRateCoef>") {
|
||||
for (int j = 3; j < vars_lines.size() - 1; j++) {
|
||||
vars_list.push_back(stof(vars_lines[j])*scale);
|
||||
vars_list_.push_back(stof(vars_lines[j])*scale);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
@ -143,11 +290,11 @@ string Paraformer::GreedySearch(float * in, int n_len, int64_t token_nums)
|
||||
|
||||
vector<float> Paraformer::ApplyLfr(const std::vector<float> &in)
|
||||
{
|
||||
int32_t in_feat_dim = fbank_opts.mel_opts.num_bins;
|
||||
int32_t in_feat_dim = fbank_opts_.mel_opts.num_bins;
|
||||
int32_t in_num_frames = in.size() / in_feat_dim;
|
||||
int32_t out_num_frames =
|
||||
(in_num_frames - lfr_window_size) / lfr_window_shift + 1;
|
||||
int32_t out_feat_dim = in_feat_dim * lfr_window_size;
|
||||
(in_num_frames - lfr_m) / lfr_n + 1;
|
||||
int32_t out_feat_dim = in_feat_dim * lfr_m;
|
||||
|
||||
std::vector<float> out(out_num_frames * out_feat_dim);
|
||||
|
||||
@ -158,7 +305,7 @@ vector<float> Paraformer::ApplyLfr(const std::vector<float> &in)
|
||||
std::copy(p_in, p_in + out_feat_dim, p_out);
|
||||
|
||||
p_out += out_feat_dim;
|
||||
p_in += lfr_window_shift * in_feat_dim;
|
||||
p_in += lfr_n * in_feat_dim;
|
||||
}
|
||||
|
||||
return out;
|
||||
@ -166,29 +313,29 @@ vector<float> Paraformer::ApplyLfr(const std::vector<float> &in)
|
||||
|
||||
void Paraformer::ApplyCmvn(std::vector<float> *v)
|
||||
{
|
||||
int32_t dim = means_list.size();
|
||||
int32_t dim = means_list_.size();
|
||||
int32_t num_frames = v->size() / dim;
|
||||
|
||||
float *p = v->data();
|
||||
|
||||
for (int32_t i = 0; i != num_frames; ++i) {
|
||||
for (int32_t k = 0; k != dim; ++k) {
|
||||
p[k] = (p[k] + means_list[k]) * vars_list[k];
|
||||
p[k] = (p[k] + means_list_[k]) * vars_list_[k];
|
||||
}
|
||||
|
||||
p += dim;
|
||||
}
|
||||
}
|
||||
|
||||
string Paraformer::Forward(float* din, int len, int flag)
|
||||
string Paraformer::Forward(float* din, int len, bool input_finished)
|
||||
{
|
||||
|
||||
int32_t in_feat_dim = fbank_opts.mel_opts.num_bins;
|
||||
int32_t in_feat_dim = fbank_opts_.mel_opts.num_bins;
|
||||
std::vector<float> wav_feats = FbankKaldi(MODEL_SAMPLE_RATE, din, len);
|
||||
wav_feats = ApplyLfr(wav_feats);
|
||||
ApplyCmvn(&wav_feats);
|
||||
|
||||
int32_t feat_dim = lfr_window_size*in_feat_dim;
|
||||
int32_t feat_dim = lfr_m*in_feat_dim;
|
||||
int32_t num_frames = wav_feats.size() / feat_dim;
|
||||
|
||||
#ifdef _WIN_X86
|
||||
@ -216,7 +363,7 @@ string Paraformer::Forward(float* din, int len, int flag)
|
||||
|
||||
string result;
|
||||
try {
|
||||
auto outputTensor = m_session->Run(Ort::RunOptions{nullptr}, m_szInputNames.data(), input_onnx.data(), input_onnx.size(), m_szOutputNames.data(), m_szOutputNames.size());
|
||||
auto outputTensor = m_session_->Run(Ort::RunOptions{nullptr}, m_szInputNames.data(), input_onnx.data(), input_onnx.size(), m_szOutputNames.data(), m_szOutputNames.size());
|
||||
std::vector<int64_t> outputShape = outputTensor[0].GetTensorTypeAndShapeInfo().GetShape();
|
||||
|
||||
int64_t outputCount = std::accumulate(outputShape.begin(), outputShape.end(), 1, std::multiplies<int64_t>());
|
||||
@ -232,13 +379,6 @@ string Paraformer::Forward(float* din, int len, int flag)
|
||||
return result;
|
||||
}
|
||||
|
||||
string Paraformer::ForwardChunk(float* din, int len, int flag)
|
||||
{
|
||||
|
||||
LOG(ERROR)<<"Not Imp!!!!!!";
|
||||
return "";
|
||||
}
|
||||
|
||||
string Paraformer::Rescoring()
|
||||
{
|
||||
LOG(ERROR)<<"Not Imp!!!!!!";
|
||||
|
||||
@ -15,38 +15,66 @@ namespace funasr {
|
||||
* https://arxiv.org/pdf/2206.08317.pdf
|
||||
*/
|
||||
private:
|
||||
//std::unique_ptr<knf::OnlineFbank> fbank_;
|
||||
knf::FbankOptions fbank_opts;
|
||||
|
||||
Vocab* vocab = nullptr;
|
||||
vector<float> means_list;
|
||||
vector<float> vars_list;
|
||||
const float scale = 22.6274169979695;
|
||||
int32_t lfr_window_size = 7;
|
||||
int32_t lfr_window_shift = 6;
|
||||
//const float scale = 22.6274169979695;
|
||||
const float scale = 1.0;
|
||||
|
||||
void LoadOnlineConfigFromYaml(const char* filename);
|
||||
void LoadCmvn(const char *filename);
|
||||
vector<float> ApplyLfr(const vector<float> &in);
|
||||
void ApplyCmvn(vector<float> *v);
|
||||
string GreedySearch( float* in, int n_len, int64_t token_nums);
|
||||
|
||||
std::shared_ptr<Ort::Session> m_session = nullptr;
|
||||
Ort::Env env_;
|
||||
Ort::SessionOptions session_options;
|
||||
|
||||
vector<string> m_strInputNames, m_strOutputNames;
|
||||
vector<const char*> m_szInputNames;
|
||||
vector<const char*> m_szOutputNames;
|
||||
|
||||
public:
|
||||
Paraformer();
|
||||
~Paraformer();
|
||||
void InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num);
|
||||
// online
|
||||
void InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num);
|
||||
// 2pass
|
||||
void InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num);
|
||||
void Reset();
|
||||
vector<float> FbankKaldi(float sample_rate, const float* waves, int len);
|
||||
string ForwardChunk(float* din, int len, int flag);
|
||||
string Forward(float* din, int len, int flag);
|
||||
string Forward(float* din, int len, bool input_finished=true);
|
||||
string GreedySearch( float* in, int n_len, int64_t token_nums);
|
||||
string Rescoring();
|
||||
|
||||
knf::FbankOptions fbank_opts_;
|
||||
vector<float> means_list_;
|
||||
vector<float> vars_list_;
|
||||
int lfr_m = PARA_LFR_M;
|
||||
int lfr_n = PARA_LFR_N;
|
||||
|
||||
// paraformer-offline
|
||||
std::shared_ptr<Ort::Session> m_session_ = nullptr;
|
||||
Ort::Env env_;
|
||||
Ort::SessionOptions session_options_;
|
||||
|
||||
vector<string> m_strInputNames, m_strOutputNames;
|
||||
vector<const char*> m_szInputNames;
|
||||
vector<const char*> m_szOutputNames;
|
||||
|
||||
// paraformer-online
|
||||
std::shared_ptr<Ort::Session> encoder_session_ = nullptr;
|
||||
std::shared_ptr<Ort::Session> decoder_session_ = nullptr;
|
||||
vector<string> en_strInputNames, en_strOutputNames;
|
||||
vector<const char*> en_szInputNames_;
|
||||
vector<const char*> en_szOutputNames_;
|
||||
vector<string> de_strInputNames, de_strOutputNames;
|
||||
vector<const char*> de_szInputNames_;
|
||||
vector<const char*> de_szOutputNames_;
|
||||
|
||||
string window_type = "hamming";
|
||||
int frame_length = 25;
|
||||
int frame_shift = 10;
|
||||
int n_mels = 80;
|
||||
int encoder_size = 512;
|
||||
int fsmn_layers = 16;
|
||||
int fsmn_lorder = 10;
|
||||
int fsmn_dims = 512;
|
||||
float cif_threshold = 1.0;
|
||||
float tail_alphas = 0.45;
|
||||
|
||||
|
||||
};
|
||||
|
||||
} // namespace funasr
|
||||
|
||||
@ -33,18 +33,20 @@ using namespace std;
|
||||
#include "model.h"
|
||||
#include "vad-model.h"
|
||||
#include "punc-model.h"
|
||||
#include "offline-stream.h"
|
||||
#include "tokenizer.h"
|
||||
#include "ct-transformer.h"
|
||||
#include "ct-transformer-online.h"
|
||||
#include "e2e-vad.h"
|
||||
#include "fsmn-vad.h"
|
||||
#include "fsmn-vad-online.h"
|
||||
#include "vocab.h"
|
||||
#include "audio.h"
|
||||
#include "fsmn-vad-online.h"
|
||||
#include "tensor.h"
|
||||
#include "util.h"
|
||||
#include "resample.h"
|
||||
#include "paraformer.h"
|
||||
#include "paraformer-online.h"
|
||||
#include "offline-stream.h"
|
||||
#include "tpass-stream.h"
|
||||
#include "tpass-online-stream.h"
|
||||
#include "funasrruntime.h"
|
||||
|
||||
29
funasr/runtime/onnxruntime/src/tpass-online-stream.cpp
Normal file
29
funasr/runtime/onnxruntime/src/tpass-online-stream.cpp
Normal file
@ -0,0 +1,29 @@
|
||||
#include "precomp.h"
|
||||
#include <unistd.h>
|
||||
|
||||
namespace funasr {
|
||||
TpassOnlineStream::TpassOnlineStream(TpassStream* tpass_stream, std::vector<int> chunk_size){
|
||||
TpassStream* tpass_obj = (TpassStream*)tpass_stream;
|
||||
if(tpass_obj->vad_handle){
|
||||
vad_online_handle = make_unique<FsmnVadOnline>((FsmnVad*)(tpass_obj->vad_handle).get());
|
||||
}else{
|
||||
LOG(ERROR)<<"asr_handle is null";
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
if(tpass_obj->asr_handle){
|
||||
asr_online_handle = make_unique<ParaformerOnline>((Paraformer*)(tpass_obj->asr_handle).get(), chunk_size);
|
||||
}else{
|
||||
LOG(ERROR)<<"asr_handle is null";
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
|
||||
TpassOnlineStream* CreateTpassOnlineStream(void* tpass_stream, std::vector<int> chunk_size)
|
||||
{
|
||||
TpassOnlineStream *mm;
|
||||
mm =new TpassOnlineStream((TpassStream*)tpass_stream, chunk_size);
|
||||
return mm;
|
||||
}
|
||||
|
||||
} // namespace funasr
|
||||
87
funasr/runtime/onnxruntime/src/tpass-stream.cpp
Normal file
87
funasr/runtime/onnxruntime/src/tpass-stream.cpp
Normal file
@ -0,0 +1,87 @@
|
||||
#include "precomp.h"
|
||||
#include <unistd.h>
|
||||
|
||||
namespace funasr {
|
||||
TpassStream::TpassStream(std::map<std::string, std::string>& model_path, int thread_num)
|
||||
{
|
||||
// VAD model
|
||||
if(model_path.find(VAD_DIR) != model_path.end()){
|
||||
string vad_model_path;
|
||||
string vad_cmvn_path;
|
||||
string vad_config_path;
|
||||
|
||||
vad_model_path = PathAppend(model_path.at(VAD_DIR), MODEL_NAME);
|
||||
if(model_path.find(VAD_QUANT) != model_path.end() && model_path.at(VAD_QUANT) == "true"){
|
||||
vad_model_path = PathAppend(model_path.at(VAD_DIR), QUANT_MODEL_NAME);
|
||||
}
|
||||
vad_cmvn_path = PathAppend(model_path.at(VAD_DIR), VAD_CMVN_NAME);
|
||||
vad_config_path = PathAppend(model_path.at(VAD_DIR), VAD_CONFIG_NAME);
|
||||
if (access(vad_model_path.c_str(), F_OK) != 0 ||
|
||||
access(vad_cmvn_path.c_str(), F_OK) != 0 ||
|
||||
access(vad_config_path.c_str(), F_OK) != 0 )
|
||||
{
|
||||
LOG(INFO) << "VAD model file is not exist, skip load vad model.";
|
||||
}else{
|
||||
vad_handle = make_unique<FsmnVad>();
|
||||
vad_handle->InitVad(vad_model_path, vad_cmvn_path, vad_config_path, thread_num);
|
||||
use_vad = true;
|
||||
}
|
||||
}
|
||||
|
||||
// AM model
|
||||
if(model_path.find(OFFLINE_MODEL_DIR) != model_path.end() && model_path.find(ONLINE_MODEL_DIR) != model_path.end()){
|
||||
// 2pass
|
||||
string am_model_path;
|
||||
string en_model_path;
|
||||
string de_model_path;
|
||||
string am_cmvn_path;
|
||||
string am_config_path;
|
||||
|
||||
am_model_path = PathAppend(model_path.at(OFFLINE_MODEL_DIR), MODEL_NAME);
|
||||
en_model_path = PathAppend(model_path.at(ONLINE_MODEL_DIR), ENCODER_NAME);
|
||||
de_model_path = PathAppend(model_path.at(ONLINE_MODEL_DIR), DECODER_NAME);
|
||||
if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){
|
||||
am_model_path = PathAppend(model_path.at(OFFLINE_MODEL_DIR), QUANT_MODEL_NAME);
|
||||
en_model_path = PathAppend(model_path.at(ONLINE_MODEL_DIR), QUANT_ENCODER_NAME);
|
||||
de_model_path = PathAppend(model_path.at(ONLINE_MODEL_DIR), QUANT_DECODER_NAME);
|
||||
}
|
||||
am_cmvn_path = PathAppend(model_path.at(ONLINE_MODEL_DIR), AM_CMVN_NAME);
|
||||
am_config_path = PathAppend(model_path.at(ONLINE_MODEL_DIR), AM_CONFIG_NAME);
|
||||
|
||||
asr_handle = make_unique<Paraformer>();
|
||||
asr_handle->InitAsr(am_model_path, en_model_path, de_model_path, am_cmvn_path, am_config_path, thread_num);
|
||||
}else{
|
||||
LOG(ERROR) <<"Can not find offline-model-dir or online-model-dir";
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// PUNC model
|
||||
if(model_path.find(PUNC_DIR) != model_path.end()){
|
||||
string punc_model_path;
|
||||
string punc_config_path;
|
||||
|
||||
punc_model_path = PathAppend(model_path.at(PUNC_DIR), MODEL_NAME);
|
||||
if(model_path.find(PUNC_QUANT) != model_path.end() && model_path.at(PUNC_QUANT) == "true"){
|
||||
punc_model_path = PathAppend(model_path.at(PUNC_DIR), QUANT_MODEL_NAME);
|
||||
}
|
||||
punc_config_path = PathAppend(model_path.at(PUNC_DIR), PUNC_CONFIG_NAME);
|
||||
|
||||
if (access(punc_model_path.c_str(), F_OK) != 0 ||
|
||||
access(punc_config_path.c_str(), F_OK) != 0 )
|
||||
{
|
||||
LOG(INFO) << "PUNC model file is not exist, skip load punc model.";
|
||||
}else{
|
||||
punc_online_handle = make_unique<CTTransformerOnline>();
|
||||
punc_online_handle->InitPunc(punc_model_path, punc_config_path, thread_num);
|
||||
use_punc = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TpassStream *CreateTpassStream(std::map<std::string, std::string>& model_path, int thread_num)
|
||||
{
|
||||
TpassStream *mm;
|
||||
mm = new TpassStream(model_path, thread_num);
|
||||
return mm;
|
||||
}
|
||||
} // namespace funasr
|
||||
@ -1,73 +1,27 @@
|
||||
# Service with grpc-python
|
||||
We can send streaming audio data to server in real-time with grpc client every 10 ms e.g., and get transcribed text when stop speaking.
|
||||
The audio data is in streaming, the asr inference process is in offline.
|
||||
# GRPC python Client for 2pass decoding
|
||||
The client can send streaming or full audio data to server as you wish, and get transcribed text once the server respond (depends on mode)
|
||||
|
||||
## For the Server
|
||||
|
||||
### Prepare server environment
|
||||
Install the modelscope and funasr
|
||||
In the demo client, audio_chunk_duration is set to 1000ms, and send_interval is set to 100ms
|
||||
|
||||
### 1. Install the requirements
|
||||
```shell
|
||||
pip install -U modelscope funasr
|
||||
# For the users in China, you could install with the command:
|
||||
# pip install -U modelscope funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple
|
||||
git clone https://github.com/alibaba/FunASR.git && cd FunASR
|
||||
git clone https://github.com/alibaba/FunASR.git && cd FunASR/funasr/runtime/python/grpc
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
Install the requirements
|
||||
|
||||
```shell
|
||||
cd funasr/runtime/python/grpc
|
||||
pip install -r requirements_server.txt
|
||||
```
|
||||
|
||||
|
||||
### Generate protobuf file
|
||||
Run on server, the two generated pb files are both used for server and client
|
||||
|
||||
### 2. Generate protobuf file
|
||||
```shell
|
||||
# paraformer_pb2.py and paraformer_pb2_grpc.py are already generated,
|
||||
# regenerate it only when you make changes to ./proto/paraformer.proto file.
|
||||
python -m grpc_tools.protoc --proto_path=./proto -I ./proto --python_out=. --grpc_python_out=./ ./proto/paraformer.proto
|
||||
python -m grpc_tools.protoc --proto_path=./proto -I ./proto --python_out=. --grpc_python_out=./ ./proto/paraformer.proto
|
||||
```
|
||||
|
||||
### Start grpc server
|
||||
|
||||
```
|
||||
# Start server.
|
||||
python grpc_main_server.py --port 10095 --backend pipeline
|
||||
```
|
||||
|
||||
|
||||
## For the client
|
||||
|
||||
### Install the requirements
|
||||
|
||||
```shell
|
||||
git clone https://github.com/alibaba/FunASR.git && cd FunASR
|
||||
cd funasr/runtime/python/grpc
|
||||
pip install -r requirements_client.txt
|
||||
```
|
||||
|
||||
### Generate protobuf file
|
||||
Run on server, the two generated pb files are both used for server and client
|
||||
|
||||
```shell
|
||||
# paraformer_pb2.py and paraformer_pb2_grpc.py are already generated,
|
||||
# regenerate it only when you make changes to ./proto/paraformer.proto file.
|
||||
python -m grpc_tools.protoc --proto_path=./proto -I ./proto --python_out=. --grpc_python_out=./ ./proto/paraformer.proto
|
||||
```
|
||||
|
||||
### Start grpc client
|
||||
### 3. Start grpc client
|
||||
```
|
||||
# Start client.
|
||||
python grpc_main_client_mic.py --host 127.0.0.1 --port 10095
|
||||
python grpc_main_client.py --host 127.0.0.1 --port 10100 --wav_path /path/to/your_test_wav.wav
|
||||
```
|
||||
|
||||
|
||||
## Workflow in desgin
|
||||
|
||||
<div align="left"><img src="proto/workflow.png" width="400"/>
|
||||
|
||||
## Acknowledge
|
||||
1. This project is maintained by [FunASR community](https://github.com/alibaba-damo-academy/FunASR).
|
||||
1. This project is maintained by [FunASR community](https://github.com/alibaba-damo-academy/FunASR).
|
||||
2. We acknowledge burkliu (刘柏基, liubaiji@xverse.cn) for contributing the grpc service.
|
||||
|
||||
@ -1,17 +0,0 @@
|
||||
import queue
|
||||
import paraformer_pb2
|
||||
|
||||
def transcribe_audio_bytes(stub, chunk, user='zksz', language='zh-CN', speaking = True, isEnd = False):
|
||||
req = paraformer_pb2.Request()
|
||||
if chunk is not None:
|
||||
req.audio_data = chunk
|
||||
req.user = user
|
||||
req.language = language
|
||||
req.speaking = speaking
|
||||
req.isEnd = isEnd
|
||||
my_queue = queue.SimpleQueue()
|
||||
my_queue.put(req)
|
||||
return stub.Recognize(iter(my_queue.get, None))
|
||||
|
||||
|
||||
|
||||
@ -1,62 +1,78 @@
|
||||
import grpc
|
||||
import json
|
||||
import time
|
||||
import asyncio
|
||||
import soundfile as sf
|
||||
import logging
|
||||
import argparse
|
||||
import soundfile as sf
|
||||
import time
|
||||
|
||||
from grpc_client import transcribe_audio_bytes
|
||||
from paraformer_pb2_grpc import ASRStub
|
||||
import grpc
|
||||
import paraformer_pb2_grpc
|
||||
from paraformer_pb2 import Request, WavFormat, DecodeMode
|
||||
|
||||
# send the audio data once
|
||||
async def grpc_rec(wav_scp, grpc_uri, asr_user, language):
|
||||
with grpc.insecure_channel(grpc_uri) as channel:
|
||||
stub = ASRStub(channel)
|
||||
for line in wav_scp:
|
||||
wav_file = line.split()[1]
|
||||
wav, _ = sf.read(wav_file, dtype='int16')
|
||||
|
||||
b = time.time()
|
||||
response = transcribe_audio_bytes(stub, wav.tobytes(), user=asr_user, language=language, speaking=False, isEnd=False)
|
||||
resp = response.next()
|
||||
text = ''
|
||||
if 'decoding' == resp.action:
|
||||
resp = response.next()
|
||||
if 'finish' == resp.action:
|
||||
text = json.loads(resp.sentence)['text']
|
||||
response = transcribe_audio_bytes(stub, None, user=asr_user, language=language, speaking=False, isEnd=True)
|
||||
res= {'text': text, 'time': time.time() - b}
|
||||
print(res)
|
||||
class GrpcClient:
|
||||
def __init__(self, wav_path, uri, mode):
|
||||
self.wav, self.sampling_rate = sf.read(wav_path, dtype='int16')
|
||||
self.wav_format = WavFormat.pcm
|
||||
self.audio_chunk_duration = 1000 # ms
|
||||
self.audio_chunk_size = int(self.sampling_rate * self.audio_chunk_duration / 1000)
|
||||
self.send_interval = 100 # ms
|
||||
self.mode = mode
|
||||
|
||||
async def test(args):
|
||||
wav_scp = open(args.wav_scp, "r").readlines()
|
||||
uri = '{}:{}'.format(args.host, args.port)
|
||||
res = await grpc_rec(wav_scp, uri, args.user_allowed, language = 'zh-CN')
|
||||
# connect to grpc server
|
||||
channel = grpc.insecure_channel(uri)
|
||||
self.stub = paraformer_pb2_grpc.ASRStub(channel)
|
||||
|
||||
# start request
|
||||
for respond in self.stub.Recognize(self.request_iterator()):
|
||||
logging.info("[receive] mode {}, text {}, is final {}".format(
|
||||
DecodeMode.Name(respond.mode), respond.text, respond.is_final))
|
||||
|
||||
def request_iterator(self, mode = DecodeMode.two_pass):
|
||||
is_first_pack = True
|
||||
is_final = False
|
||||
for start in range(0, len(self.wav), self.audio_chunk_size):
|
||||
request = Request()
|
||||
audio_chunk = self.wav[start : start + self.audio_chunk_size]
|
||||
|
||||
if is_first_pack:
|
||||
is_first_pack = False
|
||||
request.sampling_rate = self.sampling_rate
|
||||
request.mode = self.mode
|
||||
request.wav_format = self.wav_format
|
||||
if request.mode == DecodeMode.two_pass or request.mode == DecodeMode.online:
|
||||
request.chunk_size.extend([5, 10, 5])
|
||||
|
||||
if start + self.audio_chunk_size >= len(self.wav):
|
||||
is_final = True
|
||||
request.is_final = is_final
|
||||
request.audio_data = audio_chunk.tobytes()
|
||||
logging.info("[request] audio_data len {}, is final {}".format(
|
||||
len(request.audio_data), request.is_final)) # int16 = 2bytes
|
||||
time.sleep(self.send_interval / 1000)
|
||||
yield request
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--host",
|
||||
type=str,
|
||||
default="127.0.0.1",
|
||||
required=False,
|
||||
help="grpc server host ip")
|
||||
parser.add_argument("--port",
|
||||
type=int,
|
||||
default=10108,
|
||||
required=False,
|
||||
help="grpc server port")
|
||||
parser.add_argument("--user_allowed",
|
||||
type=str,
|
||||
default="project1_user1",
|
||||
help="allowed user for grpc client")
|
||||
parser.add_argument("--sample_rate",
|
||||
type=int,
|
||||
default=16000,
|
||||
help="audio sample_rate from client")
|
||||
parser.add_argument("--wav_scp",
|
||||
type=str,
|
||||
required=True,
|
||||
help="audio wav scp")
|
||||
args = parser.parse_args()
|
||||
|
||||
asyncio.run(test(args))
|
||||
logging.basicConfig(filename="", format="%(asctime)s %(message)s", level=logging.INFO)
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--host",
|
||||
type=str,
|
||||
default="127.0.0.1",
|
||||
required=False,
|
||||
help="grpc server host ip")
|
||||
parser.add_argument("--port",
|
||||
type=int,
|
||||
default=10100,
|
||||
required=False,
|
||||
help="grpc server port")
|
||||
parser.add_argument("--wav_path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="audio wav path")
|
||||
args = parser.parse_args()
|
||||
|
||||
for mode in [DecodeMode.offline, DecodeMode.online, DecodeMode.two_pass]:
|
||||
mode_name = DecodeMode.Name(mode)
|
||||
logging.info("[request] start requesting with mode {}".format(mode_name))
|
||||
|
||||
st = time.time()
|
||||
uri = '{}:{}'.format(args.host, args.port)
|
||||
client = GrpcClient(args.wav_path, uri, mode)
|
||||
logging.info("mode {}, time pass: {}".format(mode_name, time.time() - st))
|
||||
|
||||
@ -1,112 +0,0 @@
|
||||
import pyaudio
|
||||
import grpc
|
||||
import json
|
||||
import webrtcvad
|
||||
import time
|
||||
import asyncio
|
||||
import argparse
|
||||
|
||||
from grpc_client import transcribe_audio_bytes
|
||||
from paraformer_pb2_grpc import ASRStub
|
||||
|
||||
async def deal_chunk(sig_mic):
|
||||
global stub,SPEAKING,asr_user,language,sample_rate
|
||||
if vad.is_speech(sig_mic, sample_rate): #speaking
|
||||
SPEAKING = True
|
||||
response = transcribe_audio_bytes(stub, sig_mic, user=asr_user, language=language, speaking = True, isEnd = False) #speaking, send audio to server.
|
||||
else: #silence
|
||||
begin_time = 0
|
||||
if SPEAKING: #means we have some audio recorded, send recognize order to server.
|
||||
SPEAKING = False
|
||||
begin_time = int(round(time.time() * 1000))
|
||||
response = transcribe_audio_bytes(stub, None, user=asr_user, language=language, speaking = False, isEnd = False) #speak end, call server for recognize one sentence
|
||||
resp = response.next()
|
||||
if "decoding" == resp.action:
|
||||
resp = response.next() #TODO, blocking operation may leads to miss some audio clips. C++ multi-threading is preferred.
|
||||
if "finish" == resp.action:
|
||||
end_time = int(round(time.time() * 1000))
|
||||
print (json.loads(resp.sentence))
|
||||
print ("delay in ms: %d " % (end_time - begin_time))
|
||||
else:
|
||||
pass
|
||||
|
||||
|
||||
async def record(host,port,sample_rate,mic_chunk,record_seconds,asr_user,language):
|
||||
with grpc.insecure_channel('{}:{}'.format(host, port)) as channel:
|
||||
global stub
|
||||
stub = ASRStub(channel)
|
||||
for i in range(0, int(sample_rate / mic_chunk * record_seconds)):
|
||||
|
||||
sig_mic = stream.read(mic_chunk,exception_on_overflow = False)
|
||||
await asyncio.create_task(deal_chunk(sig_mic))
|
||||
|
||||
#end grpc
|
||||
response = transcribe_audio_bytes(stub, None, user=asr_user, language=language, speaking = False, isEnd = True)
|
||||
print (response.next().action)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--host",
|
||||
type=str,
|
||||
default="127.0.0.1",
|
||||
required=True,
|
||||
help="grpc server host ip")
|
||||
|
||||
parser.add_argument("--port",
|
||||
type=int,
|
||||
default=10095,
|
||||
required=True,
|
||||
help="grpc server port")
|
||||
|
||||
parser.add_argument("--user_allowed",
|
||||
type=str,
|
||||
default="project1_user1",
|
||||
help="allowed user for grpc client")
|
||||
|
||||
parser.add_argument("--sample_rate",
|
||||
type=int,
|
||||
default=16000,
|
||||
help="audio sample_rate from client")
|
||||
|
||||
parser.add_argument("--mic_chunk",
|
||||
type=int,
|
||||
default=160,
|
||||
help="chunk size for mic")
|
||||
|
||||
parser.add_argument("--record_seconds",
|
||||
type=int,
|
||||
default=120,
|
||||
help="run specified seconds then exit ")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
SPEAKING = False
|
||||
asr_user = args.user_allowed
|
||||
sample_rate = args.sample_rate
|
||||
language = 'zh-CN'
|
||||
|
||||
|
||||
vad = webrtcvad.Vad()
|
||||
vad.set_mode(1)
|
||||
|
||||
FORMAT = pyaudio.paInt16
|
||||
CHANNELS = 1
|
||||
p = pyaudio.PyAudio()
|
||||
|
||||
stream = p.open(format=FORMAT,
|
||||
channels=CHANNELS,
|
||||
rate=args.sample_rate,
|
||||
input=True,
|
||||
frames_per_buffer=args.mic_chunk)
|
||||
|
||||
print("* recording")
|
||||
asyncio.run(record(args.host,args.port,args.sample_rate,args.mic_chunk,args.record_seconds,args.user_allowed,language))
|
||||
stream.stop_stream()
|
||||
stream.close()
|
||||
p.terminate()
|
||||
print("recording stop")
|
||||
|
||||
|
||||
|
||||
@ -1,68 +0,0 @@
|
||||
import grpc
|
||||
from concurrent import futures
|
||||
import argparse
|
||||
|
||||
import paraformer_pb2_grpc
|
||||
from grpc_server import ASRServicer
|
||||
|
||||
def serve(args):
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10),
|
||||
# interceptors=(AuthInterceptor('Bearer mysecrettoken'),)
|
||||
)
|
||||
paraformer_pb2_grpc.add_ASRServicer_to_server(
|
||||
ASRServicer(args.user_allowed, args.model, args.sample_rate, args.backend, args.onnx_dir, vad_model=args.vad_model, punc_model=args.punc_model), server)
|
||||
port = "[::]:" + str(args.port)
|
||||
server.add_insecure_port(port)
|
||||
server.start()
|
||||
print("grpc server started!")
|
||||
server.wait_for_termination()
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--port",
|
||||
type=int,
|
||||
default=10095,
|
||||
required=True,
|
||||
help="grpc server port")
|
||||
|
||||
parser.add_argument("--user_allowed",
|
||||
type=str,
|
||||
default="project1_user1|project1_user2|project2_user3",
|
||||
help="allowed user for grpc client")
|
||||
|
||||
parser.add_argument("--model",
|
||||
type=str,
|
||||
default="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
|
||||
help="model from modelscope")
|
||||
parser.add_argument("--vad_model",
|
||||
type=str,
|
||||
default="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
|
||||
help="model from modelscope")
|
||||
|
||||
parser.add_argument("--punc_model",
|
||||
type=str,
|
||||
default="",
|
||||
help="model from modelscope")
|
||||
|
||||
parser.add_argument("--sample_rate",
|
||||
type=int,
|
||||
default=16000,
|
||||
help="audio sample_rate from client")
|
||||
|
||||
parser.add_argument("--backend",
|
||||
type=str,
|
||||
default="pipeline",
|
||||
choices=("pipeline", "onnxruntime"),
|
||||
help="backend, optional modelscope pipeline or onnxruntime")
|
||||
|
||||
parser.add_argument("--onnx_dir",
|
||||
type=str,
|
||||
default="/nfs/models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
|
||||
help="onnx model dir")
|
||||
|
||||
|
||||
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
serve(args)
|
||||
@ -1,132 +0,0 @@
|
||||
from concurrent import futures
|
||||
import grpc
|
||||
import json
|
||||
import time
|
||||
|
||||
import paraformer_pb2_grpc
|
||||
from paraformer_pb2 import Response
|
||||
|
||||
|
||||
class ASRServicer(paraformer_pb2_grpc.ASRServicer):
|
||||
def __init__(self, user_allowed, model, sample_rate, backend, onnx_dir, vad_model='', punc_model=''):
|
||||
print("ASRServicer init")
|
||||
self.backend = backend
|
||||
self.init_flag = 0
|
||||
self.client_buffers = {}
|
||||
self.client_transcription = {}
|
||||
self.auth_user = user_allowed.split("|")
|
||||
if self.backend == "pipeline":
|
||||
try:
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
except ImportError:
|
||||
raise ImportError(f"Please install modelscope")
|
||||
self.inference_16k_pipeline = pipeline(task=Tasks.auto_speech_recognition, model=model, vad_model=vad_model, punc_model=punc_model)
|
||||
elif self.backend == "onnxruntime":
|
||||
try:
|
||||
from funasr_onnx import Paraformer
|
||||
except ImportError:
|
||||
raise ImportError(f"Please install onnxruntime environment")
|
||||
self.inference_16k_pipeline = Paraformer(model_dir=onnx_dir)
|
||||
self.sample_rate = sample_rate
|
||||
|
||||
def clear_states(self, user):
|
||||
self.clear_buffers(user)
|
||||
self.clear_transcriptions(user)
|
||||
|
||||
def clear_buffers(self, user):
|
||||
if user in self.client_buffers:
|
||||
del self.client_buffers[user]
|
||||
|
||||
def clear_transcriptions(self, user):
|
||||
if user in self.client_transcription:
|
||||
del self.client_transcription[user]
|
||||
|
||||
def disconnect(self, user):
|
||||
self.clear_states(user)
|
||||
print("Disconnecting user: %s" % str(user))
|
||||
|
||||
def Recognize(self, request_iterator, context):
|
||||
|
||||
|
||||
for req in request_iterator:
|
||||
if req.user not in self.auth_user:
|
||||
result = {}
|
||||
result["success"] = False
|
||||
result["detail"] = "Not Authorized user: %s " % req.user
|
||||
result["text"] = ""
|
||||
yield Response(sentence=json.dumps(result), user=req.user, action="terminate", language=req.language)
|
||||
elif req.isEnd: #end grpc
|
||||
print("asr end")
|
||||
self.disconnect(req.user)
|
||||
result = {}
|
||||
result["success"] = True
|
||||
result["detail"] = "asr end"
|
||||
result["text"] = ""
|
||||
yield Response(sentence=json.dumps(result), user=req.user, action="terminate",language=req.language)
|
||||
elif req.speaking: #continue speaking
|
||||
if req.audio_data is not None and len(req.audio_data) > 0:
|
||||
if req.user in self.client_buffers:
|
||||
self.client_buffers[req.user] += req.audio_data #append audio
|
||||
else:
|
||||
self.client_buffers[req.user] = req.audio_data
|
||||
result = {}
|
||||
result["success"] = True
|
||||
result["detail"] = "speaking"
|
||||
result["text"] = ""
|
||||
yield Response(sentence=json.dumps(result), user=req.user, action="speaking", language=req.language)
|
||||
elif not req.speaking: #silence
|
||||
if req.user not in self.client_buffers:
|
||||
result = {}
|
||||
result["success"] = True
|
||||
result["detail"] = "waiting_for_more_voice"
|
||||
result["text"] = ""
|
||||
yield Response(sentence=json.dumps(result), user=req.user, action="waiting", language=req.language)
|
||||
else:
|
||||
begin_time = int(round(time.time() * 1000))
|
||||
tmp_data = self.client_buffers[req.user]
|
||||
self.clear_states(req.user)
|
||||
result = {}
|
||||
result["success"] = True
|
||||
result["detail"] = "decoding data: %d bytes" % len(tmp_data)
|
||||
result["text"] = ""
|
||||
yield Response(sentence=json.dumps(result), user=req.user, action="decoding", language=req.language)
|
||||
if len(tmp_data) < 9600: #min input_len for asr model , 300ms
|
||||
end_time = int(round(time.time() * 1000))
|
||||
delay_str = str(end_time - begin_time)
|
||||
result = {}
|
||||
result["success"] = True
|
||||
result["detail"] = "waiting_for_more_voice"
|
||||
result["server_delay_ms"] = delay_str
|
||||
result["text"] = ""
|
||||
print ("user: %s , delay(ms): %s, info: %s " % (req.user, delay_str, "waiting_for_more_voice"))
|
||||
yield Response(sentence=json.dumps(result), user=req.user, action="waiting", language=req.language)
|
||||
else:
|
||||
if self.backend == "pipeline":
|
||||
asr_result = self.inference_16k_pipeline(audio_in=tmp_data, audio_fs = self.sample_rate)
|
||||
if "text" in asr_result:
|
||||
asr_result = asr_result['text']
|
||||
else:
|
||||
asr_result = ""
|
||||
elif self.backend == "onnxruntime":
|
||||
from funasr_onnx.utils.frontend import load_bytes
|
||||
array = load_bytes(tmp_data)
|
||||
asr_result = self.inference_16k_pipeline(array)[0]
|
||||
end_time = int(round(time.time() * 1000))
|
||||
delay_str = str(end_time - begin_time)
|
||||
print ("user: %s , delay(ms): %s, text: %s " % (req.user, delay_str, asr_result))
|
||||
result = {}
|
||||
result["success"] = True
|
||||
result["detail"] = "finish_sentence"
|
||||
result["server_delay_ms"] = delay_str
|
||||
result["text"] = asr_result
|
||||
yield Response(sentence=json.dumps(result), user=req.user, action="finish", language=req.language)
|
||||
else:
|
||||
result = {}
|
||||
result["success"] = False
|
||||
result["detail"] = "error, no condition matched! Unknown reason."
|
||||
result["text"] = ""
|
||||
self.disconnect(req.user)
|
||||
yield Response(sentence=json.dumps(result), user=req.user, action="terminate", language=req.language)
|
||||
|
||||
|
||||
@ -1,30 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# source: paraformer.proto
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf.internal import builder as _builder
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10paraformer.proto\x12\nparaformer\"^\n\x07Request\x12\x12\n\naudio_data\x18\x01 \x01(\x0c\x12\x0c\n\x04user\x18\x02 \x01(\t\x12\x10\n\x08language\x18\x03 \x01(\t\x12\x10\n\x08speaking\x18\x04 \x01(\x08\x12\r\n\x05isEnd\x18\x05 \x01(\x08\"L\n\x08Response\x12\x10\n\x08sentence\x18\x01 \x01(\t\x12\x0c\n\x04user\x18\x02 \x01(\t\x12\x10\n\x08language\x18\x03 \x01(\t\x12\x0e\n\x06\x61\x63tion\x18\x04 \x01(\t2C\n\x03\x41SR\x12<\n\tRecognize\x12\x13.paraformer.Request\x1a\x14.paraformer.Response\"\x00(\x01\x30\x01\x42\x16\n\x07\x65x.grpc\xa2\x02\nparaformerb\x06proto3')
|
||||
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'paraformer_pb2', globals())
|
||||
if _descriptor._USE_C_DESCRIPTORS == False:
|
||||
|
||||
DESCRIPTOR._options = None
|
||||
DESCRIPTOR._serialized_options = b'\n\007ex.grpc\242\002\nparaformer'
|
||||
_REQUEST._serialized_start=32
|
||||
_REQUEST._serialized_end=126
|
||||
_RESPONSE._serialized_start=128
|
||||
_RESPONSE._serialized_end=204
|
||||
_ASR._serialized_start=206
|
||||
_ASR._serialized_end=273
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
@ -1,66 +0,0 @@
|
||||
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
||||
"""Client and server classes corresponding to protobuf-defined services."""
|
||||
import grpc
|
||||
|
||||
import paraformer_pb2 as paraformer__pb2
|
||||
|
||||
|
||||
class ASRStub(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
def __init__(self, channel):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
channel: A grpc.Channel.
|
||||
"""
|
||||
self.Recognize = channel.stream_stream(
|
||||
'/paraformer.ASR/Recognize',
|
||||
request_serializer=paraformer__pb2.Request.SerializeToString,
|
||||
response_deserializer=paraformer__pb2.Response.FromString,
|
||||
)
|
||||
|
||||
|
||||
class ASRServicer(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
def Recognize(self, request_iterator, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
|
||||
def add_ASRServicer_to_server(servicer, server):
|
||||
rpc_method_handlers = {
|
||||
'Recognize': grpc.stream_stream_rpc_method_handler(
|
||||
servicer.Recognize,
|
||||
request_deserializer=paraformer__pb2.Request.FromString,
|
||||
response_serializer=paraformer__pb2.Response.SerializeToString,
|
||||
),
|
||||
}
|
||||
generic_handler = grpc.method_handlers_generic_handler(
|
||||
'paraformer.ASR', rpc_method_handlers)
|
||||
server.add_generic_rpc_handlers((generic_handler,))
|
||||
|
||||
|
||||
# This class is part of an EXPERIMENTAL API.
|
||||
class ASR(object):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
|
||||
@staticmethod
|
||||
def Recognize(request_iterator,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.stream_stream(request_iterator, target, '/paraformer.ASR/Recognize',
|
||||
paraformer__pb2.Request.SerializeToString,
|
||||
paraformer__pb2.Response.FromString,
|
||||
options, channel_credentials,
|
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
||||
@ -1,3 +1,8 @@
|
||||
// Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
|
||||
// Reserved. MIT License (https://opensource.org/licenses/MIT)
|
||||
//
|
||||
// 2023 by burkliu(刘柏基) liubaiji@xverse.cn
|
||||
|
||||
syntax = "proto3";
|
||||
|
||||
option objc_class_prefix = "paraformer";
|
||||
@ -8,17 +13,27 @@ service ASR {
|
||||
rpc Recognize (stream Request) returns (stream Response) {}
|
||||
}
|
||||
|
||||
enum WavFormat {
|
||||
pcm = 0;
|
||||
}
|
||||
|
||||
enum DecodeMode {
|
||||
offline = 0;
|
||||
online = 1;
|
||||
two_pass = 2;
|
||||
}
|
||||
|
||||
message Request {
|
||||
bytes audio_data = 1;
|
||||
string user = 2;
|
||||
string language = 3;
|
||||
bool speaking = 4;
|
||||
bool isEnd = 5;
|
||||
DecodeMode mode = 1;
|
||||
WavFormat wav_format = 2;
|
||||
int32 sampling_rate = 3;
|
||||
repeated int32 chunk_size = 4;
|
||||
bool is_final = 5;
|
||||
bytes audio_data = 6;
|
||||
}
|
||||
|
||||
message Response {
|
||||
string sentence = 1;
|
||||
string user = 2;
|
||||
string language = 3;
|
||||
string action = 4;
|
||||
DecodeMode mode = 1;
|
||||
string text = 2;
|
||||
bool is_final = 3;
|
||||
}
|
||||
|
||||
@ -1,4 +1,2 @@
|
||||
pyaudio
|
||||
webrtcvad
|
||||
grpcio
|
||||
grpcio-tools
|
||||
@ -1,2 +0,0 @@
|
||||
grpcio
|
||||
grpcio-tools
|
||||
30
funasr/runtime/python/onnxruntime/demo_paraformer_online.py
Normal file
30
funasr/runtime/python/onnxruntime/demo_paraformer_online.py
Normal file
@ -0,0 +1,30 @@
|
||||
import soundfile
|
||||
from funasr_onnx.paraformer_online_bin import Paraformer
|
||||
from pathlib import Path
|
||||
|
||||
model_dir = "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online"
|
||||
wav_path = '{}/.cache/modelscope/hub/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/example/asr_example.wav'.format(Path.home())
|
||||
|
||||
chunk_size = [5, 10, 5]
|
||||
model = Paraformer(model_dir, batch_size=1, quantize=True, chunk_size=chunk_size, intra_op_num_threads=4) # only support batch_size = 1
|
||||
|
||||
##online asr
|
||||
speech, sample_rate = soundfile.read(wav_path)
|
||||
speech_length = speech.shape[0]
|
||||
sample_offset = 0
|
||||
step = chunk_size[1] * 960
|
||||
param_dict = {'cache': dict()}
|
||||
final_result = ""
|
||||
for sample_offset in range(0, speech_length, min(step, speech_length - sample_offset)):
|
||||
if sample_offset + step >= speech_length - 1:
|
||||
step = speech_length - sample_offset
|
||||
is_final = True
|
||||
else:
|
||||
is_final = False
|
||||
param_dict['is_final'] = is_final
|
||||
rec_result = model(audio_in=speech[sample_offset: sample_offset + step],
|
||||
param_dict=param_dict)
|
||||
if len(rec_result) > 0:
|
||||
final_result += rec_result[0]["preds"][0]
|
||||
print(rec_result)
|
||||
print(final_result)
|
||||
@ -0,0 +1,309 @@
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import os.path
|
||||
from pathlib import Path
|
||||
from typing import List, Union, Tuple
|
||||
|
||||
import copy
|
||||
import librosa
|
||||
import numpy as np
|
||||
|
||||
from .utils.utils import (CharTokenizer, Hypothesis, ONNXRuntimeError,
|
||||
OrtInferSession, TokenIDConverter, get_logger,
|
||||
read_yaml)
|
||||
from .utils.postprocess_utils import sentence_postprocess
|
||||
from .utils.frontend import WavFrontendOnline, SinusoidalPositionEncoderOnline
|
||||
|
||||
logging = get_logger()
|
||||
|
||||
|
||||
class Paraformer():
|
||||
def __init__(self, model_dir: Union[str, Path] = None,
|
||||
batch_size: int = 1,
|
||||
chunk_size: List = [5, 10, 5],
|
||||
device_id: Union[str, int] = "-1",
|
||||
quantize: bool = False,
|
||||
intra_op_num_threads: int = 4,
|
||||
cache_dir: str = None
|
||||
):
|
||||
|
||||
if not Path(model_dir).exists():
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
try:
|
||||
model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
|
||||
except:
|
||||
raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(model_dir)
|
||||
|
||||
encoder_model_file = os.path.join(model_dir, 'model.onnx')
|
||||
decoder_model_file = os.path.join(model_dir, 'decoder.onnx')
|
||||
if quantize:
|
||||
encoder_model_file = os.path.join(model_dir, 'model_quant.onnx')
|
||||
decoder_model_file = os.path.join(model_dir, 'decoder_quant.onnx')
|
||||
if not os.path.exists(encoder_model_file) or not os.path.exists(decoder_model_file):
|
||||
print(".onnx is not exist, begin to export onnx")
|
||||
from funasr.export.export_model import ModelExport
|
||||
export_model = ModelExport(
|
||||
cache_dir=cache_dir,
|
||||
onnx=True,
|
||||
device="cpu",
|
||||
quant=quantize,
|
||||
)
|
||||
export_model.export(model_dir)
|
||||
|
||||
config_file = os.path.join(model_dir, 'config.yaml')
|
||||
cmvn_file = os.path.join(model_dir, 'am.mvn')
|
||||
config = read_yaml(config_file)
|
||||
|
||||
self.converter = TokenIDConverter(config['token_list'])
|
||||
self.tokenizer = CharTokenizer()
|
||||
self.frontend = WavFrontendOnline(
|
||||
cmvn_file=cmvn_file,
|
||||
**config['frontend_conf']
|
||||
)
|
||||
self.pe = SinusoidalPositionEncoderOnline()
|
||||
self.ort_encoder_infer = OrtInferSession(encoder_model_file, device_id,
|
||||
intra_op_num_threads=intra_op_num_threads)
|
||||
self.ort_decoder_infer = OrtInferSession(decoder_model_file, device_id,
|
||||
intra_op_num_threads=intra_op_num_threads)
|
||||
self.batch_size = batch_size
|
||||
self.chunk_size = chunk_size
|
||||
self.encoder_output_size = config["encoder_conf"]["output_size"]
|
||||
self.fsmn_layer = config["decoder_conf"]["num_blocks"]
|
||||
self.fsmn_lorder = config["decoder_conf"]["kernel_size"] - 1
|
||||
self.fsmn_dims = config["encoder_conf"]["output_size"]
|
||||
self.feats_dims = config["frontend_conf"]["n_mels"] * config["frontend_conf"]["lfr_m"]
|
||||
self.cif_threshold = config["predictor_conf"]["threshold"]
|
||||
self.tail_threshold = config["predictor_conf"]["tail_threshold"]
|
||||
|
||||
def prepare_cache(self, cache: dict = {}, batch_size=1):
|
||||
if len(cache) > 0:
|
||||
return cache
|
||||
cache["start_idx"] = 0
|
||||
cache["cif_hidden"] = np.zeros((batch_size, 1, self.encoder_output_size)).astype(np.float32)
|
||||
cache["cif_alphas"] = np.zeros((batch_size, 1)).astype(np.float32)
|
||||
cache["chunk_size"] = self.chunk_size
|
||||
cache["last_chunk"] = False
|
||||
cache["feats"] = np.zeros((batch_size, self.chunk_size[0] + self.chunk_size[2], self.feats_dims)).astype(np.float32)
|
||||
cache["decoder_fsmn"] = []
|
||||
for i in range(self.fsmn_layer):
|
||||
fsmn_cache = np.zeros((batch_size, self.fsmn_dims, self.fsmn_lorder)).astype(np.float32)
|
||||
cache["decoder_fsmn"].append(fsmn_cache)
|
||||
return cache
|
||||
|
||||
def add_overlap_chunk(self, feats: np.ndarray, cache: dict = {}):
|
||||
if len(cache) == 0:
|
||||
return feats
|
||||
# process last chunk
|
||||
overlap_feats = np.concatenate((cache["feats"], feats), axis=1)
|
||||
if cache["is_final"]:
|
||||
cache["feats"] = overlap_feats[:, -self.chunk_size[0]:, :]
|
||||
if not cache["last_chunk"]:
|
||||
padding_length = sum(self.chunk_size) - overlap_feats.shape[1]
|
||||
overlap_feats = np.pad(overlap_feats, ((0, 0), (0, padding_length), (0, 0)))
|
||||
else:
|
||||
cache["feats"] = overlap_feats[:, -(self.chunk_size[0] + self.chunk_size[2]):, :]
|
||||
return overlap_feats
|
||||
|
||||
def __call__(self, audio_in: np.ndarray, **kwargs):
|
||||
waveforms = np.expand_dims(audio_in, axis=0)
|
||||
param_dict = kwargs.get('param_dict', dict())
|
||||
is_final = param_dict.get('is_final', False)
|
||||
cache = param_dict.get('cache', dict())
|
||||
asr_res = []
|
||||
|
||||
if waveforms.shape[1] < 16 * 60 and is_final and len(cache) > 0:
|
||||
cache["last_chunk"] = True
|
||||
feats = cache["feats"]
|
||||
feats_len = np.array([feats.shape[1]]).astype(np.int32)
|
||||
asr_res = self.infer(feats, feats_len, cache)
|
||||
return asr_res
|
||||
|
||||
feats, feats_len = self.extract_feat(waveforms, is_final)
|
||||
if feats.shape[1] != 0:
|
||||
feats *= self.encoder_output_size ** 0.5
|
||||
cache = self.prepare_cache(cache)
|
||||
cache["is_final"] = is_final
|
||||
|
||||
# fbank -> position encoding -> overlap chunk
|
||||
feats = self.pe.forward(feats, cache["start_idx"])
|
||||
cache["start_idx"] += feats.shape[1]
|
||||
if is_final:
|
||||
if feats.shape[1] + self.chunk_size[2] <= self.chunk_size[1]:
|
||||
cache["last_chunk"] = True
|
||||
feats = self.add_overlap_chunk(feats, cache)
|
||||
else:
|
||||
# first chunk
|
||||
feats_chunk1 = self.add_overlap_chunk(feats[:, :self.chunk_size[1], :], cache)
|
||||
feats_len = np.array([feats_chunk1.shape[1]]).astype(np.int32)
|
||||
asr_res_chunk1 = self.infer(feats_chunk1, feats_len, cache)
|
||||
|
||||
# last chunk
|
||||
cache["last_chunk"] = True
|
||||
feats_chunk2 = self.add_overlap_chunk(feats[:, -(feats.shape[1] + self.chunk_size[2] - self.chunk_size[1]):, :], cache)
|
||||
feats_len = np.array([feats_chunk2.shape[1]]).astype(np.int32)
|
||||
asr_res_chunk2 = self.infer(feats_chunk2, feats_len, cache)
|
||||
|
||||
asr_res_chunk = asr_res_chunk1 + asr_res_chunk2
|
||||
res = {}
|
||||
for pred in asr_res_chunk:
|
||||
for key, value in pred.items():
|
||||
if key in res:
|
||||
res[key][0] += value[0]
|
||||
res[key][1].extend(value[1])
|
||||
else:
|
||||
res[key] = [value[0], value[1]]
|
||||
return [res]
|
||||
else:
|
||||
feats = self.add_overlap_chunk(feats, cache)
|
||||
|
||||
feats_len = np.array([feats.shape[1]]).astype(np.int32)
|
||||
asr_res = self.infer(feats, feats_len, cache)
|
||||
|
||||
return asr_res
|
||||
|
||||
def infer(self, feats: np.ndarray, feats_len: np.ndarray, cache):
|
||||
# encoder forward
|
||||
enc_input = [feats, feats_len]
|
||||
enc, enc_lens, cif_alphas = self.ort_encoder_infer(enc_input)
|
||||
|
||||
# predictor forward
|
||||
acoustic_embeds, acoustic_embeds_len = self.cif_search(enc, cif_alphas, cache)
|
||||
|
||||
# decoder forward
|
||||
asr_res = []
|
||||
if acoustic_embeds.shape[1] > 0:
|
||||
dec_input = [enc, enc_lens, acoustic_embeds, acoustic_embeds_len]
|
||||
dec_input.extend(cache["decoder_fsmn"])
|
||||
dec_output = self.ort_decoder_infer(dec_input)
|
||||
logits, sample_ids, cache["decoder_fsmn"] = dec_output[0], dec_output[1], dec_output[2:]
|
||||
cache["decoder_fsmn"] = [item[:, :, -self.fsmn_lorder:] for item in cache["decoder_fsmn"]]
|
||||
|
||||
preds = self.decode(logits, acoustic_embeds_len)
|
||||
for pred in preds:
|
||||
pred = sentence_postprocess(pred)
|
||||
asr_res.append({'preds': pred})
|
||||
|
||||
return asr_res
|
||||
|
||||
def load_data(self,
|
||||
wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
|
||||
def load_wav(path: str) -> np.ndarray:
|
||||
waveform, _ = librosa.load(path, sr=fs)
|
||||
return waveform
|
||||
|
||||
if isinstance(wav_content, np.ndarray):
|
||||
return [wav_content]
|
||||
|
||||
if isinstance(wav_content, str):
|
||||
return [load_wav(wav_content)]
|
||||
|
||||
if isinstance(wav_content, list):
|
||||
return [load_wav(path) for path in wav_content]
|
||||
|
||||
raise TypeError(
|
||||
f'The type of {wav_content} is not in [str, np.ndarray, list]')
|
||||
|
||||
def extract_feat(self,
|
||||
waveforms: np.ndarray, is_final: bool = False
|
||||
) -> Tuple[np.ndarray, np.ndarray]:
|
||||
waveforms_lens = np.zeros(waveforms.shape[0]).astype(np.int32)
|
||||
for idx, waveform in enumerate(waveforms):
|
||||
waveforms_lens[idx] = waveform.shape[-1]
|
||||
|
||||
feats, feats_len = self.frontend.extract_fbank(waveforms, waveforms_lens, is_final)
|
||||
return feats.astype(np.float32), feats_len.astype(np.int32)
|
||||
|
||||
def decode(self, am_scores: np.ndarray, token_nums: int) -> List[str]:
|
||||
return [self.decode_one(am_score, token_num)
|
||||
for am_score, token_num in zip(am_scores, token_nums)]
|
||||
|
||||
def decode_one(self,
|
||||
am_score: np.ndarray,
|
||||
valid_token_num: int) -> List[str]:
|
||||
yseq = am_score.argmax(axis=-1)
|
||||
score = am_score.max(axis=-1)
|
||||
score = np.sum(score, axis=-1)
|
||||
|
||||
# pad with mask tokens to ensure compatibility with sos/eos tokens
|
||||
# asr_model.sos:1 asr_model.eos:2
|
||||
yseq = np.array([1] + yseq.tolist() + [2])
|
||||
hyp = Hypothesis(yseq=yseq, score=score)
|
||||
|
||||
# remove sos/eos and get results
|
||||
last_pos = -1
|
||||
token_int = hyp.yseq[1:last_pos].tolist()
|
||||
|
||||
# remove blank symbol id, which is assumed to be 0
|
||||
token_int = list(filter(lambda x: x not in (0, 2), token_int))
|
||||
|
||||
# Change integer-ids to tokens
|
||||
token = self.converter.ids2tokens(token_int)
|
||||
token = token[:valid_token_num]
|
||||
# texts = sentence_postprocess(token)
|
||||
return token
|
||||
|
||||
def cif_search(self, hidden, alphas, cache=None):
|
||||
batch_size, len_time, hidden_size = hidden.shape
|
||||
token_length = []
|
||||
list_fires = []
|
||||
list_frames = []
|
||||
cache_alphas = []
|
||||
cache_hiddens = []
|
||||
alphas[:, :self.chunk_size[0]] = 0.0
|
||||
alphas[:, sum(self.chunk_size[:2]):] = 0.0
|
||||
if cache is not None and "cif_alphas" in cache and "cif_hidden" in cache:
|
||||
hidden = np.concatenate((cache["cif_hidden"], hidden), axis=1)
|
||||
alphas = np.concatenate((cache["cif_alphas"], alphas), axis=1)
|
||||
if cache is not None and "last_chunk" in cache and cache["last_chunk"]:
|
||||
tail_hidden = np.zeros((batch_size, 1, hidden_size)).astype(np.float32)
|
||||
tail_alphas = np.array([[self.tail_threshold]]).astype(np.float32)
|
||||
tail_alphas =np.tile(tail_alphas, (batch_size, 1))
|
||||
hidden = np.concatenate((hidden, tail_hidden), axis=1)
|
||||
alphas = np.concatenate((alphas, tail_alphas), axis=1)
|
||||
|
||||
len_time = alphas.shape[1]
|
||||
for b in range(batch_size):
|
||||
integrate = 0.0
|
||||
frames = np.zeros(hidden_size).astype(np.float32)
|
||||
list_frame = []
|
||||
list_fire = []
|
||||
for t in range(len_time):
|
||||
alpha = alphas[b][t]
|
||||
if alpha + integrate < self.cif_threshold:
|
||||
integrate += alpha
|
||||
list_fire.append(integrate)
|
||||
frames += alpha * hidden[b][t]
|
||||
else:
|
||||
frames += (self.cif_threshold - integrate) * hidden[b][t]
|
||||
list_frame.append(frames)
|
||||
integrate += alpha
|
||||
list_fire.append(integrate)
|
||||
integrate -= self.cif_threshold
|
||||
frames = integrate * hidden[b][t]
|
||||
|
||||
cache_alphas.append(integrate)
|
||||
if integrate > 0.0:
|
||||
cache_hiddens.append(frames / integrate)
|
||||
else:
|
||||
cache_hiddens.append(frames)
|
||||
|
||||
token_length.append(len(list_frame))
|
||||
list_fires.append(list_fire)
|
||||
list_frames.append(list_frame)
|
||||
|
||||
max_token_len = max(token_length)
|
||||
list_ls = []
|
||||
for b in range(batch_size):
|
||||
pad_frames = np.zeros((max_token_len - token_length[b], hidden_size)).astype(np.float32)
|
||||
if token_length[b] == 0:
|
||||
list_ls.append(pad_frames)
|
||||
else:
|
||||
list_ls.append(np.concatenate((list_frames[b], pad_frames), axis=0))
|
||||
|
||||
cache["cif_alphas"] = np.stack(cache_alphas, axis=0)
|
||||
cache["cif_alphas"] = np.expand_dims(cache["cif_alphas"], axis=0)
|
||||
cache["cif_hidden"] = np.stack(cache_hiddens, axis=0)
|
||||
cache["cif_hidden"] = np.expand_dims(cache["cif_hidden"], axis=0)
|
||||
|
||||
return np.stack(list_ls, axis=0).astype(np.float32), np.stack(token_length, axis=0).astype(np.int32)
|
||||
|
||||
@ -349,6 +349,28 @@ def load_bytes(input):
|
||||
return array
|
||||
|
||||
|
||||
class SinusoidalPositionEncoderOnline():
|
||||
'''Streaming Positional encoding.
|
||||
'''
|
||||
|
||||
def encode(self, positions: np.ndarray = None, depth: int = None, dtype: np.dtype = np.float32):
|
||||
batch_size = positions.shape[0]
|
||||
positions = positions.astype(dtype)
|
||||
log_timescale_increment = np.log(np.array([10000], dtype=dtype)) / (depth / 2 - 1)
|
||||
inv_timescales = np.exp(np.arange(depth / 2).astype(dtype) * (-log_timescale_increment))
|
||||
inv_timescales = np.reshape(inv_timescales, [batch_size, -1])
|
||||
scaled_time = np.reshape(positions, [1, -1, 1]) * np.reshape(inv_timescales, [1, 1, -1])
|
||||
encoding = np.concatenate((np.sin(scaled_time), np.cos(scaled_time)), axis=2)
|
||||
return encoding.astype(dtype)
|
||||
|
||||
def forward(self, x, start_idx=0):
|
||||
batch_size, timesteps, input_dim = x.shape
|
||||
positions = np.arange(1, timesteps+1+start_idx)[None, :]
|
||||
position_encoding = self.encode(positions, input_dim, x.dtype)
|
||||
|
||||
return x + position_encoding[:, start_idx: start_idx + timesteps]
|
||||
|
||||
|
||||
def test():
|
||||
path = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav"
|
||||
import librosa
|
||||
|
||||
@ -58,7 +58,11 @@ add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog glog)
|
||||
find_package(OpenSSL REQUIRED)
|
||||
|
||||
add_executable(funasr-wss-server "funasr-wss-server.cpp" "websocket-server.cpp")
|
||||
add_executable(funasr-wss-server-2pass "funasr-wss-server-2pass.cpp" "websocket-server-2pass.cpp")
|
||||
add_executable(funasr-wss-client "funasr-wss-client.cpp")
|
||||
add_executable(funasr-wss-client-2pass "funasr-wss-client-2pass.cpp")
|
||||
|
||||
target_link_libraries(funasr-wss-client PUBLIC funasr ssl crypto)
|
||||
target_link_libraries(funasr-wss-client-2pass PUBLIC funasr ssl crypto)
|
||||
target_link_libraries(funasr-wss-server PUBLIC funasr ssl crypto)
|
||||
target_link_libraries(funasr-wss-server-2pass PUBLIC funasr ssl crypto)
|
||||
|
||||
430
funasr/runtime/websocket/funasr-wss-client-2pass.cpp
Normal file
430
funasr/runtime/websocket/funasr-wss-client-2pass.cpp
Normal file
@ -0,0 +1,430 @@
|
||||
/**
|
||||
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
|
||||
* Reserved. MIT License (https://opensource.org/licenses/MIT)
|
||||
*/
|
||||
/* 2022-2023 by zhaomingwork */
|
||||
|
||||
// client for websocket, support multiple threads
|
||||
// ./funasr-wss-client --server-ip <string>
|
||||
// --port <string>
|
||||
// --wav-path <string>
|
||||
// [--thread-num <int>]
|
||||
// [--is-ssl <int>] [--]
|
||||
// [--version] [-h]
|
||||
// example:
|
||||
// ./funasr-wss-client --server-ip 127.0.0.1 --port 10095 --wav-path test.wav --thread-num 1 --is-ssl 1
|
||||
|
||||
#define ASIO_STANDALONE 1
|
||||
#include <websocketpp/client.hpp>
|
||||
#include <websocketpp/common/thread.hpp>
|
||||
#include <websocketpp/config/asio_client.hpp>
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <atomic>
|
||||
#include <thread>
|
||||
#include <glog/logging.h>
|
||||
|
||||
#include "audio.h"
|
||||
#include "nlohmann/json.hpp"
|
||||
#include "tclap/CmdLine.h"
|
||||
|
||||
/**
|
||||
* Define a semi-cross platform helper method that waits/sleeps for a bit.
|
||||
*/
|
||||
void WaitABit() {
|
||||
#ifdef WIN32
|
||||
Sleep(1000);
|
||||
#else
|
||||
sleep(1);
|
||||
#endif
|
||||
}
|
||||
std::atomic<int> wav_index(0);
|
||||
|
||||
bool IsTargetFile(const std::string& filename, const std::string target) {
|
||||
std::size_t pos = filename.find_last_of(".");
|
||||
if (pos == std::string::npos) {
|
||||
return false;
|
||||
}
|
||||
std::string extension = filename.substr(pos + 1);
|
||||
return (extension == target);
|
||||
}
|
||||
|
||||
typedef websocketpp::config::asio_client::message_type::ptr message_ptr;
|
||||
typedef websocketpp::lib::shared_ptr<websocketpp::lib::asio::ssl::context> context_ptr;
|
||||
using websocketpp::lib::bind;
|
||||
using websocketpp::lib::placeholders::_1;
|
||||
using websocketpp::lib::placeholders::_2;
|
||||
context_ptr OnTlsInit(websocketpp::connection_hdl) {
|
||||
context_ptr ctx = websocketpp::lib::make_shared<asio::ssl::context>(
|
||||
asio::ssl::context::sslv23);
|
||||
|
||||
try {
|
||||
ctx->set_options(
|
||||
asio::ssl::context::default_workarounds | asio::ssl::context::no_sslv2 |
|
||||
asio::ssl::context::no_sslv3 | asio::ssl::context::single_dh_use);
|
||||
|
||||
} catch (std::exception& e) {
|
||||
LOG(ERROR) << e.what();
|
||||
}
|
||||
return ctx;
|
||||
}
|
||||
|
||||
// template for tls or not config
|
||||
template <typename T>
|
||||
class WebsocketClient {
|
||||
public:
|
||||
// typedef websocketpp::client<T> client;
|
||||
// typedef websocketpp::client<websocketpp::config::asio_tls_client>
|
||||
// wss_client;
|
||||
typedef websocketpp::lib::lock_guard<websocketpp::lib::mutex> scoped_lock;
|
||||
|
||||
WebsocketClient(int is_ssl) : m_open(false), m_done(false) {
|
||||
// set up access channels to only log interesting things
|
||||
m_client.clear_access_channels(websocketpp::log::alevel::all);
|
||||
m_client.set_access_channels(websocketpp::log::alevel::connect);
|
||||
m_client.set_access_channels(websocketpp::log::alevel::disconnect);
|
||||
m_client.set_access_channels(websocketpp::log::alevel::app);
|
||||
|
||||
// Initialize the Asio transport policy
|
||||
m_client.init_asio();
|
||||
|
||||
// Bind the handlers we are using
|
||||
using websocketpp::lib::bind;
|
||||
using websocketpp::lib::placeholders::_1;
|
||||
m_client.set_open_handler(bind(&WebsocketClient::on_open, this, _1));
|
||||
m_client.set_close_handler(bind(&WebsocketClient::on_close, this, _1));
|
||||
|
||||
m_client.set_message_handler(
|
||||
[this](websocketpp::connection_hdl hdl, message_ptr msg) {
|
||||
on_message(hdl, msg);
|
||||
});
|
||||
|
||||
m_client.set_fail_handler(bind(&WebsocketClient::on_fail, this, _1));
|
||||
m_client.clear_access_channels(websocketpp::log::alevel::all);
|
||||
}
|
||||
|
||||
void on_message(websocketpp::connection_hdl hdl, message_ptr msg) {
|
||||
const std::string& payload = msg->get_payload();
|
||||
switch (msg->get_opcode()) {
|
||||
case websocketpp::frame::opcode::text:
|
||||
nlohmann::json jsonresult = nlohmann::json::parse(payload);
|
||||
LOG(INFO)<< "Thread: " << this_thread::get_id() <<",on_message = " << payload;
|
||||
|
||||
// if (jsonresult["is_final"] == true){
|
||||
// websocketpp::lib::error_code ec;
|
||||
// m_client.close(m_hdl, websocketpp::close::status::going_away, "", ec);
|
||||
// if (ec){
|
||||
// LOG(ERROR)<< "Error closing connection " << ec.message();
|
||||
// }
|
||||
// }
|
||||
}
|
||||
}
|
||||
|
||||
// This method will block until the connection is complete
|
||||
void run(const std::string& uri, const std::vector<string>& wav_list, const std::vector<string>& wav_ids, std::string asr_mode, std::vector<int> chunk_size) {
|
||||
// Create a new connection to the given URI
|
||||
websocketpp::lib::error_code ec;
|
||||
typename websocketpp::client<T>::connection_ptr con =
|
||||
m_client.get_connection(uri, ec);
|
||||
if (ec) {
|
||||
m_client.get_alog().write(websocketpp::log::alevel::app,
|
||||
"Get Connection Error: " + ec.message());
|
||||
return;
|
||||
}
|
||||
// Grab a handle for this connection so we can talk to it in a thread
|
||||
// safe manor after the event loop starts.
|
||||
m_hdl = con->get_handle();
|
||||
|
||||
// Queue the connection. No DNS queries or network connections will be
|
||||
// made until the io_service event loop is run.
|
||||
m_client.connect(con);
|
||||
|
||||
// Create a thread to run the ASIO io_service event loop
|
||||
websocketpp::lib::thread asio_thread(&websocketpp::client<T>::run,
|
||||
&m_client);
|
||||
while(true){
|
||||
int i = wav_index.fetch_add(1);
|
||||
if (i >= wav_list.size()) {
|
||||
break;
|
||||
}
|
||||
send_wav_data(wav_list[i], wav_ids[i], asr_mode, chunk_size);
|
||||
}
|
||||
WaitABit();
|
||||
|
||||
asio_thread.join();
|
||||
|
||||
}
|
||||
|
||||
// The open handler will signal that we are ready to start sending data
|
||||
void on_open(websocketpp::connection_hdl) {
|
||||
m_client.get_alog().write(websocketpp::log::alevel::app,
|
||||
"Connection opened, starting data!");
|
||||
|
||||
scoped_lock guard(m_lock);
|
||||
m_open = true;
|
||||
}
|
||||
|
||||
// The close handler will signal that we should stop sending data
|
||||
void on_close(websocketpp::connection_hdl) {
|
||||
m_client.get_alog().write(websocketpp::log::alevel::app,
|
||||
"Connection closed, stopping data!");
|
||||
|
||||
scoped_lock guard(m_lock);
|
||||
m_done = true;
|
||||
}
|
||||
|
||||
// The fail handler will signal that we should stop sending data
|
||||
void on_fail(websocketpp::connection_hdl) {
|
||||
m_client.get_alog().write(websocketpp::log::alevel::app,
|
||||
"Connection failed, stopping data!");
|
||||
|
||||
scoped_lock guard(m_lock);
|
||||
m_done = true;
|
||||
}
|
||||
// send wav to server
|
||||
void send_wav_data(string wav_path, string wav_id, std::string asr_mode, std::vector<int> chunk_vector) {
|
||||
uint64_t count = 0;
|
||||
std::stringstream val;
|
||||
|
||||
funasr::Audio audio(1);
|
||||
int32_t sampling_rate = 16000;
|
||||
std::string wav_format = "pcm";
|
||||
if(IsTargetFile(wav_path.c_str(), "wav")){
|
||||
int32_t sampling_rate = -1;
|
||||
if(!audio.LoadWav(wav_path.c_str(), &sampling_rate))
|
||||
return ;
|
||||
}else if(IsTargetFile(wav_path.c_str(), "pcm")){
|
||||
if (!audio.LoadPcmwav(wav_path.c_str(), &sampling_rate))
|
||||
return ;
|
||||
}else{
|
||||
wav_format = "others";
|
||||
if (!audio.LoadOthers2Char(wav_path.c_str()))
|
||||
return ;
|
||||
}
|
||||
|
||||
float* buff;
|
||||
int len;
|
||||
int flag = 0;
|
||||
bool wait = false;
|
||||
while (1) {
|
||||
{
|
||||
scoped_lock guard(m_lock);
|
||||
// If the connection has been closed, stop generating data
|
||||
if (m_done) {
|
||||
break;
|
||||
}
|
||||
// If the connection hasn't been opened yet wait a bit and retry
|
||||
if (!m_open) {
|
||||
wait = true;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (wait) {
|
||||
// LOG(INFO) << "wait.." << m_open;
|
||||
WaitABit();
|
||||
continue;
|
||||
}
|
||||
}
|
||||
websocketpp::lib::error_code ec;
|
||||
|
||||
nlohmann::json jsonbegin;
|
||||
nlohmann::json chunk_size = nlohmann::json::array();
|
||||
chunk_size.push_back(chunk_vector[0]);
|
||||
chunk_size.push_back(chunk_vector[1]);
|
||||
chunk_size.push_back(chunk_vector[2]);
|
||||
jsonbegin["mode"] = asr_mode;
|
||||
jsonbegin["chunk_size"] = chunk_size;
|
||||
jsonbegin["wav_name"] = wav_id;
|
||||
jsonbegin["wav_format"] = wav_format;
|
||||
jsonbegin["is_speaking"] = true;
|
||||
m_client.send(m_hdl, jsonbegin.dump(), websocketpp::frame::opcode::text,
|
||||
ec);
|
||||
|
||||
// fetch wav data use asr engine api
|
||||
if(wav_format == "pcm"){
|
||||
while (audio.Fetch(buff, len, flag) > 0) {
|
||||
short* iArray = new short[len];
|
||||
for (size_t i = 0; i < len; ++i) {
|
||||
iArray[i] = (short)(buff[i]*32768);
|
||||
}
|
||||
|
||||
// send data to server
|
||||
int offset = 0;
|
||||
int block_size = 102400;
|
||||
while(offset < len){
|
||||
int send_block = 0;
|
||||
if (offset + block_size <= len){
|
||||
send_block = block_size;
|
||||
}else{
|
||||
send_block = len - offset;
|
||||
}
|
||||
m_client.send(m_hdl, iArray+offset, send_block * sizeof(short),
|
||||
websocketpp::frame::opcode::binary, ec);
|
||||
offset += send_block;
|
||||
}
|
||||
|
||||
LOG(INFO) << "sended data len=" << len * sizeof(short);
|
||||
// The most likely error that we will get is that the connection is
|
||||
// not in the right state. Usually this means we tried to send a
|
||||
// message to a connection that was closed or in the process of
|
||||
// closing. While many errors here can be easily recovered from,
|
||||
// in this simple example, we'll stop the data loop.
|
||||
if (ec) {
|
||||
m_client.get_alog().write(websocketpp::log::alevel::app,
|
||||
"Send Error: " + ec.message());
|
||||
break;
|
||||
}
|
||||
delete[] iArray;
|
||||
// WaitABit();
|
||||
}
|
||||
}else{
|
||||
int offset = 0;
|
||||
int block_size = 204800;
|
||||
len = audio.GetSpeechLen();
|
||||
char* others_buff = audio.GetSpeechChar();
|
||||
|
||||
while(offset < len){
|
||||
int send_block = 0;
|
||||
if (offset + block_size <= len){
|
||||
send_block = block_size;
|
||||
}else{
|
||||
send_block = len - offset;
|
||||
}
|
||||
m_client.send(m_hdl, others_buff+offset, send_block,
|
||||
websocketpp::frame::opcode::binary, ec);
|
||||
offset += send_block;
|
||||
}
|
||||
|
||||
LOG(INFO) << "sended data len=" << len;
|
||||
// The most likely error that we will get is that the connection is
|
||||
// not in the right state. Usually this means we tried to send a
|
||||
// message to a connection that was closed or in the process of
|
||||
// closing. While many errors here can be easily recovered from,
|
||||
// in this simple example, we'll stop the data loop.
|
||||
if (ec) {
|
||||
m_client.get_alog().write(websocketpp::log::alevel::app,
|
||||
"Send Error: " + ec.message());
|
||||
}
|
||||
}
|
||||
|
||||
nlohmann::json jsonresult;
|
||||
jsonresult["is_speaking"] = false;
|
||||
m_client.send(m_hdl, jsonresult.dump(), websocketpp::frame::opcode::text,
|
||||
ec);
|
||||
// WaitABit();
|
||||
}
|
||||
websocketpp::client<T> m_client;
|
||||
|
||||
private:
|
||||
websocketpp::connection_hdl m_hdl;
|
||||
websocketpp::lib::mutex m_lock;
|
||||
bool m_open;
|
||||
bool m_done;
|
||||
int total_num=0;
|
||||
};
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
FLAGS_logtostderr = true;
|
||||
|
||||
TCLAP::CmdLine cmd("funasr-wss-client", ' ', "1.0");
|
||||
TCLAP::ValueArg<std::string> server_ip_("", "server-ip", "server-ip", true,
|
||||
"127.0.0.1", "string");
|
||||
TCLAP::ValueArg<std::string> port_("", "port", "port", true, "10095", "string");
|
||||
TCLAP::ValueArg<std::string> wav_path_("", "wav-path",
|
||||
"the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)",
|
||||
true, "", "string");
|
||||
TCLAP::ValueArg<std::string> asr_mode_("", ASR_MODE, "offline, online, 2pass", false, "2pass", "string");
|
||||
TCLAP::ValueArg<std::string> chunk_size_("", "chunk-size", "chunk_size: 5-10-5 or 5-12-5", false, "5-10-5", "string");
|
||||
TCLAP::ValueArg<int> thread_num_("", "thread-num", "thread-num",
|
||||
false, 1, "int");
|
||||
TCLAP::ValueArg<int> is_ssl_(
|
||||
"", "is-ssl", "is-ssl is 1 means use wss connection, or use ws connection",
|
||||
false, 1, "int");
|
||||
|
||||
cmd.add(server_ip_);
|
||||
cmd.add(port_);
|
||||
cmd.add(wav_path_);
|
||||
cmd.add(asr_mode_);
|
||||
cmd.add(chunk_size_);
|
||||
cmd.add(thread_num_);
|
||||
cmd.add(is_ssl_);
|
||||
cmd.parse(argc, argv);
|
||||
|
||||
std::string server_ip = server_ip_.getValue();
|
||||
std::string port = port_.getValue();
|
||||
std::string wav_path = wav_path_.getValue();
|
||||
std::string asr_mode = asr_mode_.getValue();
|
||||
std::string chunk_size_str = chunk_size_.getValue();
|
||||
// get chunk_size
|
||||
std::vector<int> chunk_size;
|
||||
std::stringstream ss(chunk_size_str);
|
||||
std::string item;
|
||||
while (std::getline(ss, item, '-')) {
|
||||
try {
|
||||
chunk_size.push_back(stoi(item));
|
||||
} catch (const invalid_argument&) {
|
||||
LOG(ERROR) << "Invalid argument: " << item;
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
|
||||
int threads_num = thread_num_.getValue();
|
||||
int is_ssl = is_ssl_.getValue();
|
||||
|
||||
std::vector<websocketpp::lib::thread> client_threads;
|
||||
std::string uri = "";
|
||||
if (is_ssl == 1) {
|
||||
uri = "wss://" + server_ip + ":" + port;
|
||||
} else {
|
||||
uri = "ws://" + server_ip + ":" + port;
|
||||
}
|
||||
|
||||
// read wav_path
|
||||
std::vector<string> wav_list;
|
||||
std::vector<string> wav_ids;
|
||||
string default_id = "wav_default_id";
|
||||
if(IsTargetFile(wav_path, "scp")){
|
||||
ifstream in(wav_path);
|
||||
if (!in.is_open()) {
|
||||
printf("Failed to open scp file");
|
||||
return 0;
|
||||
}
|
||||
string line;
|
||||
while(getline(in, line))
|
||||
{
|
||||
istringstream iss(line);
|
||||
string column1, column2;
|
||||
iss >> column1 >> column2;
|
||||
wav_list.emplace_back(column2);
|
||||
wav_ids.emplace_back(column1);
|
||||
}
|
||||
in.close();
|
||||
}else{
|
||||
wav_list.emplace_back(wav_path);
|
||||
wav_ids.emplace_back(default_id);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < threads_num; i++) {
|
||||
client_threads.emplace_back([uri, wav_list, wav_ids, asr_mode, chunk_size, is_ssl]() {
|
||||
if (is_ssl == 1) {
|
||||
WebsocketClient<websocketpp::config::asio_tls_client> c(is_ssl);
|
||||
|
||||
c.m_client.set_tls_init_handler(bind(&OnTlsInit, ::_1));
|
||||
|
||||
c.run(uri, wav_list, wav_ids, asr_mode, chunk_size);
|
||||
} else {
|
||||
WebsocketClient<websocketpp::config::asio_client> c(is_ssl);
|
||||
|
||||
c.run(uri, wav_list, wav_ids, asr_mode, chunk_size);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
for (auto& t : client_threads) {
|
||||
t.join();
|
||||
}
|
||||
}
|
||||
419
funasr/runtime/websocket/funasr-wss-server-2pass.cpp
Normal file
419
funasr/runtime/websocket/funasr-wss-server-2pass.cpp
Normal file
@ -0,0 +1,419 @@
|
||||
/**
|
||||
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
|
||||
* Reserved. MIT License (https://opensource.org/licenses/MIT)
|
||||
*/
|
||||
/* 2022-2023 by zhaomingwork */
|
||||
|
||||
// io server
|
||||
// Usage:funasr-wss-server [--model_thread_num <int>] [--decoder_thread_num
|
||||
// <int>]
|
||||
// [--io_thread_num <int>] [--port <int>] [--listen_ip
|
||||
// <string>] [--punc-quant <string>] [--punc-dir <string>]
|
||||
// [--vad-quant <string>] [--vad-dir <string>] [--quantize
|
||||
// <string>] --model-dir <string> [--] [--version] [-h]
|
||||
#include <unistd.h>
|
||||
#include "websocket-server-2pass.h"
|
||||
|
||||
using namespace std;
|
||||
void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key,
|
||||
std::map<std::string, std::string>& model_path) {
|
||||
model_path.insert({key, value_arg.getValue()});
|
||||
LOG(INFO) << key << " : " << value_arg.getValue();
|
||||
}
|
||||
int main(int argc, char* argv[]) {
|
||||
try {
|
||||
google::InitGoogleLogging(argv[0]);
|
||||
FLAGS_logtostderr = true;
|
||||
|
||||
TCLAP::CmdLine cmd("funasr-wss-server", ' ', "1.0");
|
||||
TCLAP::ValueArg<std::string> download_model_dir(
|
||||
"", "download-model-dir",
|
||||
"Download model from Modelscope to download_model_dir", false,
|
||||
"/workspace/models", "string");
|
||||
TCLAP::ValueArg<std::string> offline_model_dir(
|
||||
"", OFFLINE_MODEL_DIR,
|
||||
"default: /workspace/models/offline_asr, the asr model path, which "
|
||||
"contains model_quant.onnx, config.yaml, am.mvn",
|
||||
false, "/workspace/models/offline_asr", "string");
|
||||
TCLAP::ValueArg<std::string> online_model_dir(
|
||||
"", ONLINE_MODEL_DIR,
|
||||
"default: damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online-onnx, the asr model path, which "
|
||||
"contains model_quant.onnx, config.yaml, am.mvn",
|
||||
false, "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online-onnx", "string");
|
||||
|
||||
TCLAP::ValueArg<std::string> offline_model_revision(
|
||||
"", "offline-model-revision", "ASR offline model revision", false,
|
||||
"v1.2.1", "string");
|
||||
|
||||
TCLAP::ValueArg<std::string> online_model_revision(
|
||||
"", "online-model-revision", "ASR online model revision", false,
|
||||
"v1.0.6", "string");
|
||||
|
||||
TCLAP::ValueArg<std::string> quantize(
|
||||
"", QUANTIZE,
|
||||
"true (Default), load the model of model_quant.onnx in model_dir. If "
|
||||
"set "
|
||||
"false, load the model of model.onnx in model_dir",
|
||||
false, "true", "string");
|
||||
TCLAP::ValueArg<std::string> vad_dir(
|
||||
"", VAD_DIR,
|
||||
"default: /workspace/models/vad, the vad model path, which contains "
|
||||
"model_quant.onnx, vad.yaml, vad.mvn",
|
||||
false, "/workspace/models/vad", "string");
|
||||
TCLAP::ValueArg<std::string> vad_revision(
|
||||
"", "vad-revision", "VAD model revision", false, "v1.2.0", "string");
|
||||
TCLAP::ValueArg<std::string> vad_quant(
|
||||
"", VAD_QUANT,
|
||||
"true (Default), load the model of model_quant.onnx in vad_dir. If set "
|
||||
"false, load the model of model.onnx in vad_dir",
|
||||
false, "true", "string");
|
||||
TCLAP::ValueArg<std::string> punc_dir(
|
||||
"", PUNC_DIR,
|
||||
"default: /workspace/models/punc, the punc model path, which contains "
|
||||
"model_quant.onnx, punc.yaml",
|
||||
false, "/workspace/models/punc", "string");
|
||||
TCLAP::ValueArg<std::string> punc_revision(
|
||||
"", "punc-revision", "PUNC model revision", false, "v1.0.2", "string");
|
||||
TCLAP::ValueArg<std::string> punc_quant(
|
||||
"", PUNC_QUANT,
|
||||
"true (Default), load the model of model_quant.onnx in punc_dir. If "
|
||||
"set "
|
||||
"false, load the model of model.onnx in punc_dir",
|
||||
false, "true", "string");
|
||||
|
||||
TCLAP::ValueArg<std::string> listen_ip("", "listen-ip", "listen ip", false,
|
||||
"0.0.0.0", "string");
|
||||
TCLAP::ValueArg<int> port("", "port", "port", false, 10095, "int");
|
||||
TCLAP::ValueArg<int> io_thread_num("", "io-thread-num", "io thread num",
|
||||
false, 8, "int");
|
||||
TCLAP::ValueArg<int> decoder_thread_num(
|
||||
"", "decoder-thread-num", "decoder thread num", false, 8, "int");
|
||||
TCLAP::ValueArg<int> model_thread_num("", "model-thread-num",
|
||||
"model thread num", false, 4, "int");
|
||||
|
||||
TCLAP::ValueArg<std::string> certfile(
|
||||
"", "certfile",
|
||||
"default: ../../../ssl_key/server.crt, path of certficate for WSS "
|
||||
"connection. if it is empty, it will be in WS mode.",
|
||||
false, "../../../ssl_key/server.crt", "string");
|
||||
TCLAP::ValueArg<std::string> keyfile(
|
||||
"", "keyfile",
|
||||
"default: ../../../ssl_key/server.key, path of keyfile for WSS "
|
||||
"connection",
|
||||
false, "../../../ssl_key/server.key", "string");
|
||||
|
||||
cmd.add(certfile);
|
||||
cmd.add(keyfile);
|
||||
|
||||
cmd.add(download_model_dir);
|
||||
cmd.add(offline_model_dir);
|
||||
cmd.add(online_model_dir);
|
||||
cmd.add(offline_model_revision);
|
||||
cmd.add(online_model_revision);
|
||||
cmd.add(quantize);
|
||||
cmd.add(vad_dir);
|
||||
cmd.add(vad_revision);
|
||||
cmd.add(vad_quant);
|
||||
cmd.add(punc_dir);
|
||||
cmd.add(punc_revision);
|
||||
cmd.add(punc_quant);
|
||||
|
||||
cmd.add(listen_ip);
|
||||
cmd.add(port);
|
||||
cmd.add(io_thread_num);
|
||||
cmd.add(decoder_thread_num);
|
||||
cmd.add(model_thread_num);
|
||||
cmd.parse(argc, argv);
|
||||
|
||||
std::map<std::string, std::string> model_path;
|
||||
GetValue(offline_model_dir, OFFLINE_MODEL_DIR, model_path);
|
||||
GetValue(online_model_dir, ONLINE_MODEL_DIR, model_path);
|
||||
GetValue(quantize, QUANTIZE, model_path);
|
||||
GetValue(vad_dir, VAD_DIR, model_path);
|
||||
GetValue(vad_quant, VAD_QUANT, model_path);
|
||||
GetValue(punc_dir, PUNC_DIR, model_path);
|
||||
GetValue(punc_quant, PUNC_QUANT, model_path);
|
||||
|
||||
GetValue(offline_model_revision, "offline-model-revision", model_path);
|
||||
GetValue(online_model_revision, "online-model-revision", model_path);
|
||||
GetValue(vad_revision, "vad-revision", model_path);
|
||||
GetValue(punc_revision, "punc-revision", model_path);
|
||||
|
||||
// Download model form Modelscope
|
||||
try {
|
||||
std::string s_download_model_dir = download_model_dir.getValue();
|
||||
|
||||
std::string s_vad_path = model_path[VAD_DIR];
|
||||
std::string s_vad_quant = model_path[VAD_QUANT];
|
||||
std::string s_offline_asr_path = model_path[OFFLINE_MODEL_DIR];
|
||||
std::string s_online_asr_path = model_path[ONLINE_MODEL_DIR];
|
||||
std::string s_asr_quant = model_path[QUANTIZE];
|
||||
std::string s_punc_path = model_path[PUNC_DIR];
|
||||
std::string s_punc_quant = model_path[PUNC_QUANT];
|
||||
|
||||
std::string python_cmd =
|
||||
"python -m funasr.utils.runtime_sdk_download_tool --type onnx --quantize True ";
|
||||
|
||||
if (vad_dir.isSet() && !s_vad_path.empty()) {
|
||||
std::string python_cmd_vad;
|
||||
std::string down_vad_path;
|
||||
std::string down_vad_model;
|
||||
|
||||
if (access(s_vad_path.c_str(), F_OK) == 0) {
|
||||
// local
|
||||
python_cmd_vad = python_cmd + " --model-name " + s_vad_path +
|
||||
" --export-dir ./ " + " --model_revision " +
|
||||
model_path["vad-revision"];
|
||||
down_vad_path = s_vad_path;
|
||||
} else {
|
||||
// modelscope
|
||||
LOG(INFO) << "Download model: " << s_vad_path
|
||||
<< " from modelscope: ";
|
||||
python_cmd_vad = python_cmd + " --model-name " +
|
||||
s_vad_path +
|
||||
" --export-dir " + s_download_model_dir +
|
||||
" --model_revision " + model_path["vad-revision"];
|
||||
down_vad_path =
|
||||
s_download_model_dir +
|
||||
"/" + s_vad_path;
|
||||
}
|
||||
|
||||
int ret = system(python_cmd_vad.c_str());
|
||||
if (ret != 0) {
|
||||
LOG(INFO) << "Failed to download model from modelscope. If you set local vad model path, you can ignore the errors.";
|
||||
}
|
||||
down_vad_model = down_vad_path + "/model_quant.onnx";
|
||||
if (s_vad_quant == "false" || s_vad_quant == "False" ||
|
||||
s_vad_quant == "FALSE") {
|
||||
down_vad_model = down_vad_path + "/model.onnx";
|
||||
}
|
||||
|
||||
if (access(down_vad_model.c_str(), F_OK) != 0) {
|
||||
LOG(ERROR) << down_vad_model << " do not exists.";
|
||||
exit(-1);
|
||||
} else {
|
||||
model_path[VAD_DIR] = down_vad_path;
|
||||
LOG(INFO) << "Set " << VAD_DIR << " : " << model_path[VAD_DIR];
|
||||
}
|
||||
}
|
||||
else {
|
||||
LOG(INFO) << "VAD model is not set, use default.";
|
||||
}
|
||||
|
||||
if (offline_model_dir.isSet() && !s_offline_asr_path.empty()) {
|
||||
std::string python_cmd_asr;
|
||||
std::string down_asr_path;
|
||||
std::string down_asr_model;
|
||||
|
||||
if (access(s_offline_asr_path.c_str(), F_OK) == 0) {
|
||||
// local
|
||||
python_cmd_asr = python_cmd + " --model-name " + s_offline_asr_path +
|
||||
" --export-dir ./ " + " --model_revision " +
|
||||
model_path["offline-model-revision"];
|
||||
down_asr_path = s_offline_asr_path;
|
||||
} else {
|
||||
// modelscope
|
||||
LOG(INFO) << "Download model: " << s_offline_asr_path
|
||||
<< " from modelscope : ";
|
||||
python_cmd_asr = python_cmd + " --model-name " +
|
||||
s_offline_asr_path +
|
||||
" --export-dir " + s_download_model_dir +
|
||||
" --model_revision " + model_path["offline-model-revision"];
|
||||
down_asr_path
|
||||
= s_download_model_dir + "/" + s_offline_asr_path;
|
||||
}
|
||||
|
||||
int ret = system(python_cmd_asr.c_str());
|
||||
if (ret != 0) {
|
||||
LOG(INFO) << "Failed to download model from modelscope. If you set local asr model path, you can ignore the errors.";
|
||||
}
|
||||
down_asr_model = down_asr_path + "/model_quant.onnx";
|
||||
if (s_asr_quant == "false" || s_asr_quant == "False" ||
|
||||
s_asr_quant == "FALSE") {
|
||||
down_asr_model = down_asr_path + "/model.onnx";
|
||||
}
|
||||
|
||||
if (access(down_asr_model.c_str(), F_OK) != 0) {
|
||||
LOG(ERROR) << down_asr_model << " do not exists.";
|
||||
exit(-1);
|
||||
} else {
|
||||
model_path[OFFLINE_MODEL_DIR] = down_asr_path;
|
||||
LOG(INFO) << "Set " << OFFLINE_MODEL_DIR << " : " << model_path[OFFLINE_MODEL_DIR];
|
||||
}
|
||||
} else {
|
||||
LOG(INFO) << "ASR Offline model is not set, use default.";
|
||||
}
|
||||
|
||||
if (!s_online_asr_path.empty()) {
|
||||
std::string python_cmd_asr;
|
||||
std::string down_asr_path;
|
||||
std::string down_asr_model;
|
||||
|
||||
if (access(s_online_asr_path.c_str(), F_OK) == 0) {
|
||||
// local
|
||||
python_cmd_asr = python_cmd + " --model-name " + s_online_asr_path +
|
||||
" --export-dir ./ " + " --model_revision " +
|
||||
model_path["online-model-revision"];
|
||||
down_asr_path = s_online_asr_path;
|
||||
} else {
|
||||
// modelscope
|
||||
LOG(INFO) << "Download model: " << s_online_asr_path
|
||||
<< " from modelscope : ";
|
||||
python_cmd_asr = python_cmd + " --model-name " +
|
||||
s_online_asr_path +
|
||||
" --export-dir " + s_download_model_dir +
|
||||
" --model_revision " + model_path["online-model-revision"];
|
||||
down_asr_path
|
||||
= s_download_model_dir + "/" + s_online_asr_path;
|
||||
}
|
||||
|
||||
int ret = system(python_cmd_asr.c_str());
|
||||
if (ret != 0) {
|
||||
LOG(INFO) << "Failed to download model from modelscope. If you set local asr model path, you can ignore the errors.";
|
||||
}
|
||||
down_asr_model = down_asr_path + "/model_quant.onnx";
|
||||
if (s_asr_quant == "false" || s_asr_quant == "False" ||
|
||||
s_asr_quant == "FALSE") {
|
||||
down_asr_model = down_asr_path + "/model.onnx";
|
||||
}
|
||||
|
||||
if (access(down_asr_model.c_str(), F_OK) != 0) {
|
||||
LOG(ERROR) << down_asr_model << " do not exists.";
|
||||
exit(-1);
|
||||
} else {
|
||||
model_path[ONLINE_MODEL_DIR] = down_asr_path;
|
||||
LOG(INFO) << "Set " << ONLINE_MODEL_DIR << " : " << model_path[ONLINE_MODEL_DIR];
|
||||
}
|
||||
} else {
|
||||
LOG(INFO) << "ASR online model is not set, use default.";
|
||||
}
|
||||
|
||||
if (punc_dir.isSet() && !s_punc_path.empty()) {
|
||||
std::string python_cmd_punc;
|
||||
std::string down_punc_path;
|
||||
std::string down_punc_model;
|
||||
|
||||
if (access(s_punc_path.c_str(), F_OK) == 0) {
|
||||
// local
|
||||
python_cmd_punc = python_cmd + " --model-name " + s_punc_path +
|
||||
" --export-dir ./ " + " --model_revision " +
|
||||
model_path["punc-revision"];
|
||||
down_punc_path = s_punc_path;
|
||||
} else {
|
||||
// modelscope
|
||||
LOG(INFO) << "Download model: " << s_punc_path
|
||||
<< " from modelscope : "; python_cmd_punc = python_cmd + " --model-name " +
|
||||
s_punc_path +
|
||||
" --export-dir " + s_download_model_dir +
|
||||
" --model_revision " + model_path["punc-revision"];
|
||||
down_punc_path =
|
||||
s_download_model_dir +
|
||||
"/" + s_punc_path;
|
||||
}
|
||||
|
||||
int ret = system(python_cmd_punc.c_str());
|
||||
if (ret != 0) {
|
||||
LOG(INFO) << "Failed to download model from modelscope. If you set local punc model path, you can ignore the errors.";
|
||||
}
|
||||
down_punc_model = down_punc_path + "/model_quant.onnx";
|
||||
if (s_punc_quant == "false" || s_punc_quant == "False" ||
|
||||
s_punc_quant == "FALSE") {
|
||||
down_punc_model = down_punc_path + "/model.onnx";
|
||||
}
|
||||
|
||||
if (access(down_punc_model.c_str(), F_OK) != 0) {
|
||||
LOG(ERROR) << down_punc_model << " do not exists.";
|
||||
exit(-1);
|
||||
} else {
|
||||
model_path[PUNC_DIR] = down_punc_path;
|
||||
LOG(INFO) << "Set " << PUNC_DIR << " : " << model_path[PUNC_DIR];
|
||||
}
|
||||
} else {
|
||||
LOG(INFO) << "PUNC model is not set, use default.";
|
||||
}
|
||||
|
||||
} catch (std::exception const& e) {
|
||||
LOG(ERROR) << "Error: " << e.what();
|
||||
}
|
||||
|
||||
std::string s_listen_ip = listen_ip.getValue();
|
||||
int s_port = port.getValue();
|
||||
int s_io_thread_num = io_thread_num.getValue();
|
||||
int s_decoder_thread_num = decoder_thread_num.getValue();
|
||||
|
||||
int s_model_thread_num = model_thread_num.getValue();
|
||||
|
||||
asio::io_context io_decoder; // context for decoding
|
||||
asio::io_context io_server; // context for server
|
||||
|
||||
std::vector<std::thread> decoder_threads;
|
||||
|
||||
std::string s_certfile = certfile.getValue();
|
||||
std::string s_keyfile = keyfile.getValue();
|
||||
|
||||
bool is_ssl = false;
|
||||
if (!s_certfile.empty()) {
|
||||
is_ssl = true;
|
||||
}
|
||||
|
||||
auto conn_guard = asio::make_work_guard(
|
||||
io_decoder); // make sure threads can wait in the queue
|
||||
auto server_guard = asio::make_work_guard(
|
||||
io_server); // make sure threads can wait in the queue
|
||||
// create threads pool
|
||||
for (int32_t i = 0; i < s_decoder_thread_num; ++i) {
|
||||
decoder_threads.emplace_back([&io_decoder]() { io_decoder.run(); });
|
||||
}
|
||||
|
||||
server server_; // server for websocket
|
||||
wss_server wss_server_;
|
||||
if (is_ssl) {
|
||||
wss_server_.init_asio(&io_server); // init asio
|
||||
wss_server_.set_reuse_addr(
|
||||
true); // reuse address as we create multiple threads
|
||||
|
||||
// list on port for accept
|
||||
wss_server_.listen(asio::ip::address::from_string(s_listen_ip), s_port);
|
||||
WebSocketServer websocket_srv(
|
||||
io_decoder, is_ssl, nullptr, &wss_server_, s_certfile,
|
||||
s_keyfile); // websocket server for asr engine
|
||||
websocket_srv.initAsr(model_path, s_model_thread_num); // init asr model
|
||||
|
||||
} else {
|
||||
server_.init_asio(&io_server); // init asio
|
||||
server_.set_reuse_addr(
|
||||
true); // reuse address as we create multiple threads
|
||||
|
||||
// list on port for accept
|
||||
server_.listen(asio::ip::address::from_string(s_listen_ip), s_port);
|
||||
WebSocketServer websocket_srv(
|
||||
io_decoder, is_ssl, &server_, nullptr, s_certfile,
|
||||
s_keyfile); // websocket server for asr engine
|
||||
websocket_srv.initAsr(model_path, s_model_thread_num); // init asr model
|
||||
}
|
||||
|
||||
std::cout << "asr model init finished. listen on port:" << s_port
|
||||
<< std::endl;
|
||||
|
||||
// Start the ASIO network io_service run loop
|
||||
std::vector<std::thread> ts;
|
||||
// create threads for io network
|
||||
for (size_t i = 0; i < s_io_thread_num; i++) {
|
||||
ts.emplace_back([&io_server]() { io_server.run(); });
|
||||
}
|
||||
// wait for theads
|
||||
for (size_t i = 0; i < s_io_thread_num; i++) {
|
||||
ts[i].join();
|
||||
}
|
||||
|
||||
// wait for theads
|
||||
for (auto& t : decoder_threads) {
|
||||
t.join();
|
||||
}
|
||||
|
||||
} catch (std::exception const& e) {
|
||||
std::cerr << "Error: " << e.what() << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
@ -79,7 +79,7 @@ int main(int argc, char* argv[]) {
|
||||
TCLAP::ValueArg<int> decoder_thread_num(
|
||||
"", "decoder-thread-num", "decoder thread num", false, 8, "int");
|
||||
TCLAP::ValueArg<int> model_thread_num("", "model-thread-num",
|
||||
"model thread num", false, 1, "int");
|
||||
"model thread num", false, 4, "int");
|
||||
|
||||
TCLAP::ValueArg<std::string> certfile("", "certfile",
|
||||
"default: ../../../ssl_key/server.crt, path of certficate for WSS connection. if it is empty, it will be in WS mode.",
|
||||
|
||||
@ -116,6 +116,18 @@ Export Detailed Introduction([docs](https://github.com/alibaba-damo-academy/Fu
|
||||
--punc-dir ./export/damo/punc_ct-transformer_zh-cn-common-vocab272727-onnx
|
||||
```
|
||||
|
||||
##### Start the 2pass Service
```shell
./funasr-wss-server-2pass \
  --download-model-dir /workspace/models \
  --offline-model-dir ./export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx \
  --vad-dir ./export/damo/speech_fsmn_vad_zh-cn-16k-common-onnx \
  --punc-dir ./export/damo/punc_ct-transformer_zh-cn-common-vocab272727-onnx \
  --online-model-dir ./export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online \
  --quantize false
```
|
||||
|
||||
|
||||
### Client Usage
|
||||
|
||||
|
||||
|
||||
369
funasr/runtime/websocket/websocket-server-2pass.cpp
Normal file
369
funasr/runtime/websocket/websocket-server-2pass.cpp
Normal file
@ -0,0 +1,369 @@
|
||||
/**
|
||||
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
|
||||
* Reserved. MIT License (https://opensource.org/licenses/MIT)
|
||||
*/
|
||||
/* 2022-2023 by zhaomingwork */
|
||||
|
||||
// websocket server for asr engine
|
||||
// take some ideas from https://github.com/k2-fsa/sherpa-onnx
|
||||
// online-websocket-server-impl.cc, thanks. The websocket server has two threads
|
||||
// pools, one for handle network data and one for asr decoder.
|
||||
// supports offline, online and 2pass decoding modes.
|
||||
|
||||
#include "websocket-server-2pass.h"
|
||||
|
||||
#include <thread>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
// Build the TLS context handed to asio for a new WSS connection.
// The cert chain and private key are loaded from the paths given on the
// command line; failures are logged and an (unconfigured) context is still
// returned.
context_ptr WebSocketServer::on_tls_init(tls_mode mode,
                                         websocketpp::connection_hdl hdl,
                                         std::string& s_certfile,
                                         std::string& s_keyfile) {
  namespace asio = websocketpp::lib::asio;

  LOG(INFO) << "on_tls_init called with hdl: " << hdl.lock().get();
  LOG(INFO) << "using TLS mode: "
            << (mode == MOZILLA_MODERN ? "Mozilla Modern"
                                       : "Mozilla Intermediate");

  context_ptr ssl_ctx = websocketpp::lib::make_shared<asio::ssl::context>(
      asio::ssl::context::sslv23);

  try {
    // Both profiles disable SSLv2/SSLv3; the modern profile additionally
    // disables TLSv1.
    long ctx_options = asio::ssl::context::default_workarounds |
                       asio::ssl::context::no_sslv2 |
                       asio::ssl::context::no_sslv3 |
                       asio::ssl::context::single_dh_use;
    if (mode == MOZILLA_MODERN) {
      ctx_options |= asio::ssl::context::no_tlsv1;
    }
    ssl_ctx->set_options(ctx_options);

    // Server certificate chain and matching private key (PEM format).
    ssl_ctx->use_certificate_chain_file(s_certfile);
    ssl_ctx->use_private_key_file(s_keyfile, asio::ssl::context::pem);

  } catch (std::exception& e) {
    LOG(INFO) << "Exception: " << e.what();
  }
  return ssl_ctx;
}
|
||||
|
||||
// Collect the online (1st-pass) and offline (2nd-pass) recognition text from
// an engine result and package it as the json message sent to the client.
// The partial texts are also appended to online_res / tpass_res.
//
// "msg" may be either the full per-connection message object (from which
// "wav_name" is read) or the wav_name value itself — callers historically
// passed msg["wav_name"] directly, in which case the old
// msg.contains("wav_name") check never matched and wav_name was lost.
nlohmann::json handle_result(FUNASR_RESULT result, std::string& online_res,
                             std::string& tpass_res, nlohmann::json msg) {
  nlohmann::json jsonresult;
  jsonresult["text"] = "";

  // 1st pass: streaming (online) partial result.
  std::string tmp_online_msg = FunASRGetResult(result, 0);
  online_res += tmp_online_msg;
  if (tmp_online_msg != "") {
    LOG(INFO) << "online_res :" << tmp_online_msg;
    jsonresult["text"] = tmp_online_msg;
    jsonresult["mode"] = "2pass-online";
  }

  // 2nd pass: offline result; when present it overwrites the online text.
  std::string tmp_tpass_msg = FunASRGetTpassResult(result, 0);
  tpass_res += tmp_tpass_msg;
  if (tmp_tpass_msg != "") {
    LOG(INFO) << "offline results : " << tmp_tpass_msg;
    jsonresult["text"] = tmp_tpass_msg;
    jsonresult["mode"] = "2pass-offline";
  }

  // Propagate wav_name whichever form the caller used.
  if (msg.is_string()) {
    jsonresult["wav_name"] = msg;
  } else if (msg.is_object() && msg.contains("wav_name")) {
    jsonresult["wav_name"] = msg["wav_name"];
  }

  return jsonresult;
}
|
||||
// feed buffer to asr engine for decoder
|
||||
void WebSocketServer::do_decoder(
|
||||
std::vector<char>& buffer, websocketpp::connection_hdl& hdl,
|
||||
nlohmann::json& msg, std::vector<std::vector<std::string>>& punc_cache,
|
||||
websocketpp::lib::mutex& thread_lock, bool& is_final,
|
||||
FUNASR_HANDLE& tpass_online_handle, std::string& online_res,
|
||||
std::string& tpass_res) {
|
||||
|
||||
// lock for each connection
|
||||
scoped_lock guard(thread_lock);
|
||||
FUNASR_RESULT Result = nullptr;
|
||||
int asr_mode_ = 2;
|
||||
if (msg.contains("mode")) {
|
||||
std::string modeltype = msg["mode"];
|
||||
if (modeltype == "offline") {
|
||||
asr_mode_ = 0;
|
||||
} else if (modeltype == "online") {
|
||||
asr_mode_ = 1;
|
||||
} else if (modeltype == "2pass") {
|
||||
asr_mode_ = 2;
|
||||
}
|
||||
} else {
|
||||
// default value
|
||||
msg["mode"] = "2pass";
|
||||
asr_mode_ = 2;
|
||||
}
|
||||
|
||||
try {
|
||||
// loop to send chunk_size 800*2 data to asr engine. TODO: chunk_size need get from client
|
||||
while (buffer.size() >= 800 * 2) {
|
||||
std::vector<char> subvector = {buffer.begin(),
|
||||
buffer.begin() + 800 * 2};
|
||||
buffer.erase(buffer.begin(), buffer.begin() + 800 * 2);
|
||||
|
||||
try{
|
||||
Result =
|
||||
FunTpassInferBuffer(tpass_handle, tpass_online_handle,
|
||||
subvector.data(), subvector.size(), punc_cache,
|
||||
false, msg["audio_fs"], msg["wav_format"], (ASR_TYPE)asr_mode_);
|
||||
}catch (std::exception const &e)
|
||||
{
|
||||
LOG(ERROR)<<e.what();
|
||||
}
|
||||
if (Result) {
|
||||
websocketpp::lib::error_code ec;
|
||||
nlohmann::json jsonresult =
|
||||
handle_result(Result, online_res, tpass_res, msg["wav_name"]);
|
||||
jsonresult["is_final"] = false;
|
||||
if(jsonresult["text"] != "") {
|
||||
if (is_ssl) {
|
||||
wss_server_->send(hdl, jsonresult.dump(),
|
||||
websocketpp::frame::opcode::text, ec);
|
||||
} else {
|
||||
server_->send(hdl, jsonresult.dump(),
|
||||
websocketpp::frame::opcode::text, ec);
|
||||
}
|
||||
}
|
||||
FunASRFreeResult(Result);
|
||||
}
|
||||
|
||||
}
|
||||
if(is_final){
|
||||
|
||||
try{
|
||||
Result = FunTpassInferBuffer(tpass_handle, tpass_online_handle,
|
||||
buffer.data(), buffer.size(), punc_cache,
|
||||
is_final, msg["audio_fs"], msg["wav_format"], (ASR_TYPE)asr_mode_);
|
||||
}catch (std::exception const &e)
|
||||
{
|
||||
LOG(ERROR)<<e.what();
|
||||
}
|
||||
for(auto &vec:punc_cache){
|
||||
vec.clear();
|
||||
}
|
||||
if (Result) {
|
||||
websocketpp::lib::error_code ec;
|
||||
nlohmann::json jsonresult =
|
||||
handle_result(Result, online_res, tpass_res, msg["wav_name"]);
|
||||
jsonresult["is_final"] = true;
|
||||
if (is_ssl) {
|
||||
wss_server_->send(hdl, jsonresult.dump(),
|
||||
websocketpp::frame::opcode::text, ec);
|
||||
} else {
|
||||
server_->send(hdl, jsonresult.dump(),
|
||||
websocketpp::frame::opcode::text, ec);
|
||||
}
|
||||
FunASRFreeResult(Result);
|
||||
}
|
||||
}
|
||||
|
||||
} catch (std::exception const& e) {
|
||||
std::cerr << "Error: " << e.what() << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// Register a freshly-opened connection: allocate its per-connection state and
// insert it into data_map. Also garbage-collects entries whose sessions have
// already gone away.
void WebSocketServer::on_open(websocketpp::connection_hdl hdl) {
  scoped_lock guard(m_lock);      // protects data_map
  check_and_clean_connection();   // drop entries of closed sessions

  // Fresh state for this connection, with decoding defaults.
  auto conn_state = std::make_shared<FUNASR_MESSAGE>();
  conn_state->samples = std::make_shared<std::vector<char>>();
  // NOTE(review): raw new — never deleted anywhere visible (see on_close);
  // confirm the leak is intended before changing ownership.
  conn_state->thread_lock = new websocketpp::lib::mutex();
  conn_state->punc_cache =
      std::make_shared<std::vector<std::vector<std::string>>>(2);

  conn_state->msg = nlohmann::json::object();
  conn_state->msg["wav_format"] = "pcm";
  conn_state->msg["audio_fs"] = 16000;

  // The online decoder handle is created lazily once the client sends its
  // "chunk_size" configuration (see on_message).

  data_map.emplace(hdl, conn_state);
  LOG(INFO) << "on_open, active connections: " << data_map.size();
}
|
||||
|
||||
// Tear down a closed connection: wait for any in-flight decode, release the
// online decoder handle, and remove the entry from data_map.
void WebSocketServer::on_close(websocketpp::connection_hdl hdl) {
  scoped_lock guard(m_lock);
  auto entry = data_map.find(hdl);
  if (entry == data_map.end()) {
    return;  // unknown connection, nothing to release
  }
  std::shared_ptr<FUNASR_MESSAGE> conn_state = entry->second;

  // Block until any running do_decoder() finishes so tpass_online_handle is
  // not freed while still in use.
  scoped_lock guard_decoder(*(conn_state->thread_lock));
  FunTpassOnlineUninit(conn_state->tpass_online_handle);
  data_map.erase(hdl);
  // NOTE(review): conn_state->thread_lock (raw new in on_open) is not deleted
  // here; deleting it while a queued decode task may still reference it would
  // be unsafe — confirm intended.
  LOG(INFO) << "on_close, active connections: " << data_map.size();
}
|
||||
|
||||
// remove closed connection
|
||||
void WebSocketServer::check_and_clean_connection() {
|
||||
std::vector<websocketpp::connection_hdl> to_remove; // remove list
|
||||
auto iter = data_map.begin();
|
||||
while (iter != data_map.end()) { // loop to find closed connection
|
||||
websocketpp::connection_hdl hdl = iter->first;
|
||||
|
||||
if (is_ssl) {
|
||||
wss_server::connection_ptr con = wss_server_->get_con_from_hdl(hdl);
|
||||
if (con->get_state() != 1) { // session::state::open ==1
|
||||
to_remove.push_back(hdl);
|
||||
}
|
||||
} else {
|
||||
server::connection_ptr con = server_->get_con_from_hdl(hdl);
|
||||
if (con->get_state() != 1) { // session::state::open ==1
|
||||
to_remove.push_back(hdl);
|
||||
}
|
||||
}
|
||||
|
||||
iter++;
|
||||
}
|
||||
for (auto hdl : to_remove) {
|
||||
data_map.erase(hdl);
|
||||
LOG(INFO) << "remove one connection ";
|
||||
}
|
||||
}
|
||||
// Dispatch one incoming websocket frame.
// Text frames carry json control messages (wav_name, mode, wav_format,
// audio_fs, chunk_size, end-of-stream); binary frames carry raw PCM bytes
// that are buffered and posted to the decoder thread pool in chunks.
void WebSocketServer::on_message(websocketpp::connection_hdl hdl,
                                 message_ptr msg) {
  unique_lock lock(m_lock);
  // Find the per-connection state for this handle.
  std::shared_ptr<FUNASR_MESSAGE> msg_data = nullptr;
  auto it_data = data_map.find(hdl);
  if (it_data != data_map.end()) {
    msg_data = it_data->second;
  }
  if (msg_data == nullptr) {
    // Entry already cleaned up (or never registered): drop the frame instead
    // of dereferencing a null pointer, which the previous code did.
    LOG(INFO) << "error when fetch sample data vector";
    return;
  }

  std::shared_ptr<std::vector<char>> sample_data_p = msg_data->samples;
  std::shared_ptr<std::vector<std::vector<std::string>>> punc_cache_p =
      msg_data->punc_cache;
  websocketpp::lib::mutex* thread_lock_p = msg_data->thread_lock;

  lock.unlock();

  if (sample_data_p == nullptr) {
    LOG(INFO) << "error when fetch sample data vector";
    return;
  }

  const std::string& payload = msg->get_payload();  // get msg type

  switch (msg->get_opcode()) {
    case websocketpp::frame::opcode::text: {
      nlohmann::json jsonresult = nlohmann::json::parse(payload);

      // Copy recognised parameters into the per-connection message.
      if (jsonresult.contains("wav_name")) {
        msg_data->msg["wav_name"] = jsonresult["wav_name"];
      }
      if (jsonresult.contains("mode")) {
        msg_data->msg["mode"] = jsonresult["mode"];
      }
      if (jsonresult.contains("wav_format")) {
        msg_data->msg["wav_format"] = jsonresult["wav_format"];
      }
      if (jsonresult.contains("audio_fs")) {
        msg_data->msg["audio_fs"] = jsonresult["audio_fs"];
      }
      // First chunk_size message creates the online decoder for this
      // connection; later ones are ignored.
      if (jsonresult.contains("chunk_size")) {
        if (msg_data->tpass_online_handle == NULL) {
          std::vector<int> chunk_size_vec =
              jsonresult["chunk_size"].get<std::vector<int>>();
          FUNASR_HANDLE tpass_online_handle =
              FunTpassOnlineInit(tpass_handle, chunk_size_vec);
          msg_data->tpass_online_handle = tpass_online_handle;
        }
      }
      LOG(INFO) << "jsonresult=" << jsonresult << ", msg_data->msg="
                << msg_data->msg;
      if (jsonresult["is_speaking"] == false ||
          jsonresult["is_finished"] == true) {
        LOG(INFO) << "client done";

        // End of stream: post whatever is buffered for a final decode.
        asio::post(
            io_decoder_,
            std::bind(&WebSocketServer::do_decoder, this,
                      std::move(*(sample_data_p.get())), std::move(hdl),
                      std::ref(msg_data->msg), std::ref(*(punc_cache_p.get())),
                      std::ref(*thread_lock_p), true,
                      std::ref(msg_data->tpass_online_handle),
                      std::ref(msg_data->online_res),
                      std::ref(msg_data->tpass_res)));
      }
      break;
    }
    case websocketpp::frame::opcode::binary: {
      // received binary pcm data
      const auto* pcm_data = static_cast<const char*>(payload.data());
      size_t num_samples = payload.size();

      if (isonline) {
        sample_data_p->insert(sample_data_p->end(), pcm_data,
                              pcm_data + num_samples);
        const size_t setpsize = 800 * 2;  // TODO: need get from client
        // Once more than one step is buffered, post the largest whole-step
        // prefix to the decoder and keep the remainder buffered.
        if (sample_data_p->size() > setpsize) {
          size_t chunksize = sample_data_p->size() / setpsize;
          std::vector<char> subvector = {
              sample_data_p->begin(),
              sample_data_p->begin() + chunksize * setpsize};
          sample_data_p->erase(sample_data_p->begin(),
                               sample_data_p->begin() + chunksize * setpsize);
          asio::post(io_decoder_,
                     std::bind(&WebSocketServer::do_decoder, this,
                               std::move(subvector), std::move(hdl),
                               std::ref(msg_data->msg),
                               std::ref(*(punc_cache_p.get())),
                               std::ref(*thread_lock_p), false,
                               std::ref(msg_data->tpass_online_handle),
                               std::ref(msg_data->online_res),
                               std::ref(msg_data->tpass_res)));
        }
      } else {
        // Offline path: just accumulate until end-of-stream.
        sample_data_p->insert(sample_data_p->end(), pcm_data,
                              pcm_data + num_samples);
      }
      break;
    }
    default:
      break;
  }
}
|
||||
|
||||
// init asr model
|
||||
void WebSocketServer::initAsr(std::map<std::string, std::string>& model_path,
|
||||
int thread_num) {
|
||||
try {
|
||||
tpass_handle = FunTpassInit(model_path, thread_num);
|
||||
if (!tpass_handle) {
|
||||
LOG(ERROR) << "FunTpassInit init failed";
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
} catch (const std::exception& e) {
|
||||
LOG(INFO) << e.what();
|
||||
}
|
||||
}
|
||||
148
funasr/runtime/websocket/websocket-server-2pass.h
Normal file
148
funasr/runtime/websocket/websocket-server-2pass.h
Normal file
@ -0,0 +1,148 @@
|
||||
/**
|
||||
* Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
|
||||
* Reserved. MIT License (https://opensource.org/licenses/MIT)
|
||||
*/
|
||||
/* 2022-2023 by zhaomingwork */
|
||||
|
||||
// websocket server for asr engine
|
||||
// take some ideas from https://github.com/k2-fsa/sherpa-onnx
|
||||
// online-websocket-server-impl.cc, thanks. The websocket server has two threads
|
||||
// pools, one for handling network data and one for the asr decoder.
// Supports offline, online and 2pass decoding modes.
|
||||
|
||||
// Guard renamed from WEBSOCKET_SERVER_H_: that macro collides with the
// non-2pass websocket-server header, silently skipping whichever header is
// included second.
#ifndef WEBSOCKET_SERVER_2PASS_H_
#define WEBSOCKET_SERVER_2PASS_H_
|
||||
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <utility>
|
||||
#define ASIO_STANDALONE 1 // not boost
|
||||
#include <glog/logging.h>
|
||||
|
||||
#include <fstream>
|
||||
#include <functional>
|
||||
#include <websocketpp/common/thread.hpp>
|
||||
#include <websocketpp/config/asio.hpp>
|
||||
#include <websocketpp/server.hpp>
|
||||
|
||||
#include "asio.hpp"
|
||||
#include "com-define.h"
|
||||
#include "funasrruntime.h"
|
||||
#include "nlohmann/json.hpp"
|
||||
#include "tclap/CmdLine.h"
|
||||
typedef websocketpp::server<websocketpp::config::asio> server;
|
||||
typedef websocketpp::server<websocketpp::config::asio_tls> wss_server;
|
||||
typedef server::message_ptr message_ptr;
|
||||
using websocketpp::lib::bind;
|
||||
using websocketpp::lib::placeholders::_1;
|
||||
using websocketpp::lib::placeholders::_2;
|
||||
|
||||
typedef websocketpp::lib::lock_guard<websocketpp::lib::mutex> scoped_lock;
|
||||
typedef websocketpp::lib::unique_lock<websocketpp::lib::mutex> unique_lock;
|
||||
typedef websocketpp::lib::shared_ptr<websocketpp::lib::asio::ssl::context>
|
||||
context_ptr;
|
||||
|
||||
// Recognition result payload: recognized text plus the snippet time of the
// decoded audio (units not established here — presumably seconds, confirm
// against the engine API).
// NOTE(review): appears unused in this header's companion .cpp — confirm
// external users before removing.
struct FUNASR_RECOG_RESULT {
  std::string msg;     // recognized text
  float snippet_time;  // decoded snippet duration
};
|
||||
|
||||
// Per-connection state, kept in WebSocketServer::data_map for the lifetime of
// one websocket session (created in on_open, removed in on_close).
typedef struct {
  // Decode parameters: wav_name / mode / wav_format / audio_fs, updated from
  // client json control messages.
  nlohmann::json msg;
  // Buffered raw PCM bytes awaiting decode.
  std::shared_ptr<std::vector<char>> samples;
  // Punctuation model caches (sized 2 in on_open); cleared after a final flush.
  std::shared_ptr<std::vector<std::vector<std::string>>> punc_cache;
  // lock for each connection; serializes do_decoder against on_close.
  // NOTE(review): allocated with raw new in on_open and never deleted —
  // confirm the leak is intended.
  websocketpp::lib::mutex* thread_lock;
  // Online decoder handle; created lazily on the first "chunk_size" message,
  // released by FunTpassOnlineUninit in on_close.
  FUNASR_HANDLE tpass_online_handle=NULL;
  // Accumulated 1st-pass (online) text for this session.
  std::string online_res = "";
  // Accumulated 2nd-pass (offline) text for this session.
  std::string tpass_res = "";

} FUNASR_MESSAGE;
|
||||
|
||||
// See https://wiki.mozilla.org/Security/Server_Side_TLS for more details about
// the TLS modes. The code below demonstrates how to implement both the modern
// and the intermediate profiles.
|
||||
enum tls_mode { MOZILLA_INTERMEDIATE = 1, MOZILLA_MODERN = 2 };
|
||||
// Websocket front-end for the two-pass asr engine. Owns no sockets itself:
// it wires its callbacks into an externally-created ws or wss endpoint
// (exactly one of server/wss_server is non-null, selected by is_ssl) and
// posts decoding work onto io_decoder.
class WebSocketServer {
 public:
  // Registers tls/message/open/close handlers on the chosen endpoint and
  // starts accepting connections. s_certfile/s_keyfile are only used in the
  // ssl branch (copied into the tls-init bind).
  WebSocketServer(asio::io_context& io_decoder, bool is_ssl, server* server,
                  wss_server* wss_server, std::string& s_certfile,
                  std::string& s_keyfile)
      : io_decoder_(io_decoder),
        is_ssl(is_ssl),
        server_(server),
        wss_server_(wss_server) {
    if (is_ssl) {
      std::cout << "certfile path is " << s_certfile << std::endl;
      // TLS setup callback: builds the ssl context from cert/key files.
      wss_server->set_tls_init_handler(
          bind<context_ptr>(&WebSocketServer::on_tls_init, this,
                            MOZILLA_INTERMEDIATE, ::_1, s_certfile, s_keyfile));
      // set message handle
      wss_server_->set_message_handler(
          [this](websocketpp::connection_hdl hdl, message_ptr msg) {
            on_message(hdl, msg);
          });
      // set open handle
      wss_server_->set_open_handler(
          [this](websocketpp::connection_hdl hdl) { on_open(hdl); });
      // set close handle
      wss_server_->set_close_handler(
          [this](websocketpp::connection_hdl hdl) { on_close(hdl); });
      // begin accept
      wss_server_->start_accept();
      // not print log
      wss_server_->clear_access_channels(websocketpp::log::alevel::all);

    } else {
      // set message handle
      server_->set_message_handler(
          [this](websocketpp::connection_hdl hdl, message_ptr msg) {
            on_message(hdl, msg);
          });
      // set open handle
      server_->set_open_handler(
          [this](websocketpp::connection_hdl hdl) { on_open(hdl); });
      // set close handle
      server_->set_close_handler(
          [this](websocketpp::connection_hdl hdl) { on_close(hdl); });
      // begin accept
      server_->start_accept();
      // not print log
      server_->clear_access_channels(websocketpp::log::alevel::all);
    }
  }
  // Decode buffered PCM on the decoder pool and send results to the client;
  // thread_lock serializes decoding per connection.
  void do_decoder(std::vector<char>& buffer, websocketpp::connection_hdl& hdl,
                  nlohmann::json& msg,
                  std::vector<std::vector<std::string>>& punc_cache,
                  websocketpp::lib::mutex& thread_lock, bool& is_final,
                  FUNASR_HANDLE& tpass_online_handle, std::string& online_res,
                  std::string& tpass_res);

  // One-time engine initialization; exits the process on failure.
  void initAsr(std::map<std::string, std::string>& model_path, int thread_num);
  // Endpoint callbacks (see the .cpp for details).
  void on_message(websocketpp::connection_hdl hdl, message_ptr msg);
  void on_open(websocketpp::connection_hdl hdl);
  void on_close(websocketpp::connection_hdl hdl);
  context_ptr on_tls_init(tls_mode mode, websocketpp::connection_hdl hdl,
                          std::string& s_certfile, std::string& s_keyfile);

 private:
  // Drop data_map entries whose session is no longer open; caller holds m_lock.
  void check_and_clean_connection();
  asio::io_context& io_decoder_;  // threads for asr decoder
  // std::ofstream fout;
  // FUNASR_HANDLE asr_handle; // asr engine handle
  FUNASR_HANDLE tpass_handle=NULL;  // shared two-pass engine handle (initAsr)
  // When true, binary frames are chunked and posted for streaming decode;
  // when false they are only buffered until end-of-stream.
  bool isonline = true;
  bool is_ssl = true;
  server* server_;          // plain ws endpoint (used when !is_ssl)
  wss_server* wss_server_;  // tls wss endpoint (used when is_ssl)

  // Per-connection state keyed by connection handle; owner_less makes the
  // weak-ptr-based handles usable as map keys.

  std::map<websocketpp::connection_hdl, std::shared_ptr<FUNASR_MESSAGE>,
           std::owner_less<websocketpp::connection_hdl>>
      data_map;
  websocketpp::lib::mutex m_lock;  // guards data_map
};
|
||||
|
||||
#endif // WEBSOCKET_SERVER_H_
|
||||
Loading…
Reference in New Issue
Block a user