From 585172c9ec754330f4865f4a042bdb0c8bb54c08 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=8C=97=E5=BF=B5?=
Date: Tue, 21 Feb 2023 11:18:02 +0800
Subject: [PATCH] support grpc+onnxruntime

---
Reviewer notes (text in this section is not part of the commit message):

* The received copy of this patch had every newline and indent collapsed.
  Line breaks and blank context lines were rebuilt from the hunk headers;
  indentation is inferred. Context whitespace may therefore differ from the
  tree, so apply with `git apply --ignore-whitespace`.
* The author e-mail address is absent from the From: header of this copy,
  so `git am` may refuse the mail as-is.
* Changes relative to the original commit:
    1. The two new symlinks pointed at ../onnxruntime/paraformer/rapid_paraformer/,
       while load_bytes() is added under onnxruntime/rapid_paraformer/ by this
       same patch; the targets now point at the directory that is patched.
    2. ASRServicer.__init__ raises ValueError for an unknown backend instead
       of leaving self.inference_16k_pipeline unset; stray f-string prefixes
       on the ImportError messages are dropped (message text unchanged).
    3. load_bytes() loses its unreachable dtype checks and the redundant
       np.asarray / np.frombuffer round-trips (same result), gains a
       docstring, and the file now ends with a newline.
* The commit id on the first line and the post-image blob ids on the
  "index" lines are those of the original commit.

 .../runtime/python/grpc/grpc_main_server.py   | 16 +++++++--
 funasr/runtime/python/grpc/grpc_server.py     | 39 +++++++++++++------
 funasr/runtime/python/grpc/paraformer_onnx.py |  1 +
 funasr/runtime/python/grpc/utils              |  1 +
 .../rapid_paraformer/utils/frontend.py        | 19 ++++++++++
 5 files changed, 64 insertions(+), 12 deletions(-)
 create mode 120000 funasr/runtime/python/grpc/paraformer_onnx.py
 create mode 120000 funasr/runtime/python/grpc/utils

diff --git a/funasr/runtime/python/grpc/grpc_main_server.py b/funasr/runtime/python/grpc/grpc_main_server.py
index f3b2348ee..e862ac495 100644
--- a/funasr/runtime/python/grpc/grpc_main_server.py
+++ b/funasr/runtime/python/grpc/grpc_main_server.py
@@ -9,7 +9,8 @@ def serve(args):
     server = grpc.server(futures.ThreadPoolExecutor(max_workers=10),
                          # interceptors=(AuthInterceptor('Bearer mysecrettoken'),)
                          )
-    paraformer_pb2_grpc.add_ASRServicer_to_server(ASRServicer(args.user_allowed, args.model, args.sample_rate), server)
+    paraformer_pb2_grpc.add_ASRServicer_to_server(
+        ASRServicer(args.user_allowed, args.model, args.sample_rate, args.backend, args.onnx_dir), server)
     port = "[::]:" + str(args.port)
     server.add_insecure_port(port)
     server.start()
@@ -37,7 +38,18 @@ if __name__ == '__main__':
     parser.add_argument("--sample_rate",
                         type=int,
                         default=16000,
-                        help="audio sample_rate from client")
+                        help="audio sample_rate from client")
+
+    parser.add_argument("--backend",
+                        type=str,
+                        default="pipeline",
+                        choices=("pipeline", "onnxruntime"),
+                        help="backend, optional modelscope pipeline or onnxruntime")
+
+    parser.add_argument("--onnx_dir",
+                        type=str,
+                        default="/nfs/models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
+                        help="onnx model dir")
 
 
 
diff --git a/funasr/runtime/python/grpc/grpc_server.py b/funasr/runtime/python/grpc/grpc_server.py
index 19b735468..2d03f9dcf 100644
--- a/funasr/runtime/python/grpc/grpc_server.py
+++ b/funasr/runtime/python/grpc/grpc_server.py
@@ -3,20 +3,35 @@ import grpc
 import json
 import time
 
-from modelscope.pipelines import pipeline
-from modelscope.utils.constant import Tasks
 import paraformer_pb2_grpc
 from paraformer_pb2 import Response
+from utils.frontend import load_bytes
 
 
 class ASRServicer(paraformer_pb2_grpc.ASRServicer):
-    def __init__(self, user_allowed, model, sample_rate):
+    def __init__(self, user_allowed, model, sample_rate, backend, onnx_dir):
         print("ASRServicer init")
+        self.backend = backend
         self.init_flag = 0
         self.client_buffers = {}
         self.client_transcription = {}
         self.auth_user = user_allowed.split("|")
-        self.inference_16k_pipeline = pipeline(task=Tasks.auto_speech_recognition, model=model)
+        # Import lazily so that only the selected backend has to be installed.
+        if self.backend == "pipeline":
+            try:
+                from modelscope.pipelines import pipeline
+                from modelscope.utils.constant import Tasks
+            except ImportError:
+                raise ImportError("Please install modelscope")
+            self.inference_16k_pipeline = pipeline(task=Tasks.auto_speech_recognition, model=model)
+        elif self.backend == "onnxruntime":
+            try:
+                from paraformer_onnx import Paraformer
+            except ImportError:
+                raise ImportError("Please install onnxruntime requirements, reference https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/runtime/python/onnxruntime/rapid_paraformer")
+            self.inference_16k_pipeline = Paraformer(model_dir=onnx_dir)
+        else:
+            raise ValueError(f"Unsupported backend: {backend!r}")
         self.sample_rate = sample_rate
 
     def clear_states(self, user):
@@ -90,12 +105,16 @@ class ASRServicer(paraformer_pb2_grpc.ASRServicer):
                             result["text"] = ""
                             print ("user: %s , delay(ms): %s, info: %s " % (req.user, delay_str, "waiting_for_more_voice"))
                             yield Response(sentence=json.dumps(result), user=req.user, action="waiting", language=req.language)
-                        else:
-                            asr_result = self.inference_16k_pipeline(audio_in=tmp_data, audio_fs = self.sample_rate)
-                            if "text" in asr_result:
-                                asr_result = asr_result['text']
-                            else:
-                                asr_result = ""
+                        else:
+                            if self.backend == "pipeline":
+                                asr_result = self.inference_16k_pipeline(audio_in=tmp_data, audio_fs = self.sample_rate)
+                                if "text" in asr_result:
+                                    asr_result = asr_result['text']
+                                else:
+                                    asr_result = ""
+                            elif self.backend == "onnxruntime":
+                                array = load_bytes(tmp_data)
+                                asr_result = self.inference_16k_pipeline(array)[0]
                             end_time = int(round(time.time() * 1000))
                             delay_str = str(end_time - begin_time)
                             print ("user: %s , delay(ms): %s, text: %s " % (req.user, delay_str, asr_result))
diff --git a/funasr/runtime/python/grpc/paraformer_onnx.py b/funasr/runtime/python/grpc/paraformer_onnx.py
new file mode 120000
index 000000000..a05b2235f
--- /dev/null
+++ b/funasr/runtime/python/grpc/paraformer_onnx.py
@@ -0,0 +1 @@
+../onnxruntime/rapid_paraformer/paraformer_onnx.py
\ No newline at end of file
diff --git a/funasr/runtime/python/grpc/utils b/funasr/runtime/python/grpc/utils
new file mode 120000
index 000000000..831d965b8
--- /dev/null
+++ b/funasr/runtime/python/grpc/utils
@@ -0,0 +1 @@
+../onnxruntime/rapid_paraformer/utils
\ No newline at end of file
diff --git a/funasr/runtime/python/onnxruntime/rapid_paraformer/utils/frontend.py b/funasr/runtime/python/onnxruntime/rapid_paraformer/utils/frontend.py
index eb8a7c869..a3ca729d2 100644
--- a/funasr/runtime/python/onnxruntime/rapid_paraformer/utils/frontend.py
+++ b/funasr/runtime/python/onnxruntime/rapid_paraformer/utils/frontend.py
@@ -134,3 +134,22 @@ class WavFrontend():
         vars = np.array(vars_list).astype(np.float64)
         cmvn = np.array([means, vars])
         return cmvn
+
+
+def load_bytes(input):
+    """Convert raw 16-bit PCM bytes to a float32 waveform in [-1.0, 1.0).
+
+    Args:
+        input: bytes-like object holding native-endian int16 samples.
+
+    Returns:
+        1-D np.float32 array with one value per input sample.
+
+    Raises:
+        ValueError: if the buffer length is not a multiple of 2 bytes.
+    """
+    middle_data = np.frombuffer(input, dtype=np.int16)
+    i = np.iinfo(middle_data.dtype)
+    abs_max = 2 ** (i.bits - 1)
+    offset = i.min + abs_max
+    return (middle_data.astype(np.float32) - offset) / abs_max