From c2d1a9560017d8ae6d775751732552924f2a4a02 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Fri, 31 Mar 2023 17:18:44 +0800 Subject: [PATCH 01/15] export --- .../python/websocket/ASR_server_2pass.py | 252 ++++++++++++++++++ funasr/runtime/python/websocket/README.md | 7 +- 2 files changed, 258 insertions(+), 1 deletion(-) create mode 100644 funasr/runtime/python/websocket/ASR_server_2pass.py diff --git a/funasr/runtime/python/websocket/ASR_server_2pass.py b/funasr/runtime/python/websocket/ASR_server_2pass.py new file mode 100644 index 000000000..55dc2e299 --- /dev/null +++ b/funasr/runtime/python/websocket/ASR_server_2pass.py @@ -0,0 +1,252 @@ +import asyncio +import json +import websockets +import time +from queue import Queue +import threading +import argparse + +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger +import logging +import tracemalloc +import numpy as np + +tracemalloc.start() + +logger = get_logger(log_level=logging.CRITICAL) +logger.setLevel(logging.CRITICAL) + + +websocket_users = set() #维护客户端列表 + +parser = argparse.ArgumentParser() +parser.add_argument("--host", + type=str, + default="0.0.0.0", + required=False, + help="host ip, localhost, 0.0.0.0") +parser.add_argument("--port", + type=int, + default=10095, + required=False, + help="grpc server port") +parser.add_argument("--asr_model", + type=str, + default="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", + help="model from modelscope") +parser.add_argument("--vad_model", + type=str, + default="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", + help="model from modelscope") + +parser.add_argument("--punc_model", + type=str, + default="damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727", + help="model from modelscope") +parser.add_argument("--ngpu", + type=int, + default=1, + help="0 for cpu, 1 for gpu") + +args = parser.parse_args() + +print("model loading") + +def load_bytes(input): + middle_data = np.frombuffer(input, dtype=np.int16) + middle_data = np.asarray(middle_data) + if middle_data.dtype.kind not in 'iu': + raise TypeError("'middle_data' must be an array of integers") + dtype = np.dtype('float32') + if dtype.kind != 'f': + raise TypeError("'dtype' must be a floating point type") + + i = np.iinfo(middle_data.dtype) + abs_max = 2 ** (i.bits - 1) + offset = i.min + abs_max + array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32) + return array + +# vad +inference_pipeline_vad = pipeline( + task=Tasks.voice_activity_detection, + model=args.vad_model, + model_revision=None, + output_dir=None, + batch_size=1, + mode='online', + ngpu=args.ngpu, +) +# param_dict_vad = {'in_cache': dict(), "is_final": False} + +# asr +param_dict_asr = {} +# param_dict["hotword"] = "小五 小五月" # 设置热词,用空格隔开 +inference_pipeline_asr = pipeline( + task=Tasks.auto_speech_recognition, + model=args.asr_model, + param_dict=param_dict_asr, + ngpu=args.ngpu, +) +if args.punc_model != "": + # param_dict_punc = {'cache': list()} + inference_pipeline_punc = pipeline( + task=Tasks.punctuation, + model=args.punc_model, + model_revision=None, + ngpu=args.ngpu, + ) +else: + inference_pipeline_punc = None + + +inference_pipeline_asr_online = pipeline( + task=Tasks.auto_speech_recognition, + model='damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online', + model_revision='v1.0.2') + + +print("model loaded") + + + +async def ws_serve(websocket, path): + #speek = Queue() + frames = [] # 存储所有的帧数据 + frames_online = [] # 存储所有的帧数据 + buffer = [] # 存储缓存中的帧数据(最多两个片段) + RECORD_NUM = 0 + global websocket_users + speech_start, speech_end = False, False + # 调用asr函数 + websocket.param_dict_vad = {'in_cache': dict(), "is_final": False} + websocket.param_dict_punc = {'cache': list()} + websocket.speek = Queue() #websocket 添加进队列对象 让asr读取语音数据包 + websocket.send_msg = Queue() #websocket 添加个队列对象 让ws发送消息到客户端 + websocket_users.add(websocket) + ss = threading.Thread(target=asr, args=(websocket,)) + ss.start() + + websocket.param_dict_asr_online = {"cache": dict(), "is_final": False} + websocket.speek_online = Queue() # websocket 添加进队列对象 让asr读取语音数据包 + ss_online = threading.Thread(target=asr_online, args=(websocket,)) + ss_online.start() + + try: + async for message in websocket: + #voices.put(message) + #print("put") + #await websocket.send("123") + buffer.append(message) + if len(buffer) > 2: + buffer.pop(0) # 如果缓存超过两个片段,则删除最早的一个 + + if speech_start: + frames.append(message) + frames_online.append(message) + RECORD_NUM += 1 + if RECORD_NUM % 6 == 0: + audio_in = b"".join(frames_online) + websocket.speek_online.put(audio_in) + frames_online = [] + + speech_start_i, speech_end_i = vad(message, websocket) + #print(speech_start_i, speech_end_i) + if speech_start_i: + RECORD_NUM += 1 + speech_start = speech_start_i + frames = [] + frames.extend(buffer) # 把之前2个语音数据快加入 + frames_online = [] + frames_online.append(message) + # frames_online.extend(buffer) + # RECORD_NUM += 1 + websocket.param_dict_asr_online["is_final"] = False + if speech_end_i or RECORD_NUM > 300: + speech_start = False + audio_in = b"".join(frames) + websocket.speek.put(audio_in) + frames = [] # 清空所有的帧数据 + frames_online = [] + websocket.param_dict_asr_online["is_final"] = True + buffer = [] # 清空缓存中的帧数据(最多两个片段) + RECORD_NUM = 0 + if not websocket.send_msg.empty(): + await websocket.send(websocket.send_msg.get()) + websocket.send_msg.task_done() + + + except websockets.ConnectionClosed: + print("ConnectionClosed...", websocket_users) # 链接断开 + websocket_users.remove(websocket) + except websockets.InvalidState: + print("InvalidState...") # 无效状态 + except Exception as e: + print("Exception:", e) + + +def asr(websocket): # ASR推理 + global inference_pipeline_asr + # global param_dict_punc + global websocket_users + while websocket in websocket_users: + if not websocket.speek.empty(): + audio_in = websocket.speek.get() + websocket.speek.task_done() + if len(audio_in) > 0: + rec_result = inference_pipeline_asr(audio_in=audio_in) + if inference_pipeline_punc is not None and 'text' in rec_result: + rec_result = inference_pipeline_punc(text_in=rec_result['text'], param_dict=websocket.param_dict_punc) + # print(rec_result) + if "text" in rec_result: + message = json.dumps({"mode": "offline", "text": rec_result["text"]}) + websocket.send_msg.put(message) # 存入发送队列 直接调用send发送不了 + + time.sleep(0.1) + + +def asr_online(websocket): # ASR推理 + global inference_pipeline_asr_online + # global param_dict_punc + global websocket_users + while websocket in websocket_users: + if not websocket.speek_online.empty(): + audio_in = websocket.speek_online.get() + websocket.speek_online.task_done() + if len(audio_in) > 0: + # print(len(audio_in)) + audio_in = load_bytes(audio_in) + # print(audio_in.shape) + rec_result = inference_pipeline_asr_online(audio_in=audio_in, param_dict=websocket.param_dict_asr_online) + + # print(rec_result) + if "text" in rec_result: + message = json.dumps({"mode": "online", "text": rec_result["text"]}) + websocket.send_msg.put(message) # 存入发送队列 直接调用send发送不了 + + time.sleep(0.1) + +def vad(data, websocket): # VAD推理 + global inference_pipeline_vad, param_dict_vad + #print(type(data)) + # print(param_dict_vad) + segments_result = inference_pipeline_vad(audio_in=data, param_dict=websocket.param_dict_vad) + # print(segments_result) + # print(param_dict_vad) + speech_start = False + speech_end = False + + if len(segments_result) == 0 or len(segments_result["text"]) > 1: + return speech_start, speech_end + if segments_result["text"][0][0] != -1: + speech_start = True + if segments_result["text"][0][1] != -1: + speech_end = True + return speech_start, speech_end + + +start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None) +asyncio.get_event_loop().run_until_complete(start_server) +asyncio.get_event_loop().run_forever() \ No newline at end of file diff --git a/funasr/runtime/python/websocket/README.md b/funasr/runtime/python/websocket/README.md index 353cfa6ac..d8e7bf19e 100644 --- a/funasr/runtime/python/websocket/README.md +++ b/funasr/runtime/python/websocket/README.md @@ -25,6 +25,11 @@ Start server ```shell python ASR_server.py --host "0.0.0.0" --port 10095 --asr_model "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" ``` +For the paraformer 2pass model + +```shell +python ASR_server_2pass.py --host "0.0.0.0" --port 10095 --asr_model "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" +``` ## For the client @@ -38,7 +43,7 @@ pip install -r requirements_client.txt Start client ```shell -python ASR_client.py --host "127.0.0.1" --port 10095 --chunk_size 300 +python ASR_client.py --host "127.0.0.1" --port 10095 --chunk_size 50 ``` ## Acknowledge From 63bcaf7093c986e0c33b14212fca3705c0877864 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Fri, 31 Mar 2023 17:28:00 +0800 Subject: [PATCH 02/15] export --- funasr/runtime/python/websocket/ASR_client.py | 5 ++++- funasr/runtime/python/websocket/ASR_server.py | 4 +++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/funasr/runtime/python/websocket/ASR_client.py b/funasr/runtime/python/websocket/ASR_client.py index fe6798127..d1fb93a93 100644 --- a/funasr/runtime/python/websocket/ASR_client.py +++ b/funasr/runtime/python/websocket/ASR_client.py @@ -6,6 +6,7 @@ import asyncio from queue import Queue # import threading import argparse +import json parser = argparse.ArgumentParser() parser.add_argument("--host", @@ -78,7 +79,9 @@ async def message(): global websocket while True: try: - print(await websocket.recv()) + meg = await websocket.recv() + meg = json.loads(meg) + print(meg) except Exception as e: print("Exception:", e) diff --git a/funasr/runtime/python/websocket/ASR_server.py b/funasr/runtime/python/websocket/ASR_server.py index 827df7b58..c717e7126 100644 --- a/funasr/runtime/python/websocket/ASR_server.py +++ b/funasr/runtime/python/websocket/ASR_server.py @@ -4,6 +4,7 @@ import time from queue import Queue import threading import argparse +import json from modelscope.pipelines import pipeline from modelscope.utils.constant import Tasks @@ -157,7 +158,8 @@ def asr(websocket): # ASR推理 rec_result = inference_pipeline_punc(text_in=rec_result['text'], param_dict=websocket.param_dict_punc) # print(rec_result) if "text" in rec_result: - websocket.send_msg.put(rec_result["text"]) # 存入发送队列 直接调用send发送不了 + message = json.dumps({"mode": "offline", "text": rec_result["text"]}) + websocket.send_msg.put(message) # 存入发送队列 直接调用send发送不了 time.sleep(0.1) From 24b341a7eb0ad72e021470b8f2d1ee1d0b29ea81 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Sun, 23 Apr 2023 19:57:10 +0800 Subject: [PATCH 03/15] client websocket --- funasr/runtime/python/websocket/ASR_client.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/funasr/runtime/python/websocket/ASR_client.py b/funasr/runtime/python/websocket/ASR_client.py index d1fb93a93..fa953288a 100644 --- a/funasr/runtime/python/websocket/ASR_client.py +++ b/funasr/runtime/python/websocket/ASR_client.py @@ -23,6 +23,10 @@ parser.add_argument("--chunk_size", type=int, default=300, help="ms") +parser.add_argument("--audio_in", + type=str, + default=None, + help="audio_in") args = parser.parse_args() From 678a6c0f7293a86fb1046cf043afec29e88fd5f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Mon, 24 Apr 2023 15:54:54 +0800 Subject: [PATCH 04/15] websocket --- funasr/runtime/python/websocket/ASR_client.py | 41 +++++++++++++++---- .../python/websocket/ASR_server_2pass.py | 2 +- 2 files changed, 35 insertions(+), 8 deletions(-) diff --git a/funasr/runtime/python/websocket/ASR_client.py b/funasr/runtime/python/websocket/ASR_client.py index fa953288a..cc0e7b6e4 100644 --- a/funasr/runtime/python/websocket/ASR_client.py +++ b/funasr/runtime/python/websocket/ASR_client.py @@ -1,9 +1,8 @@ -import pyaudio + # import websocket #区别服务端这里是 websocket-client库 import time import websockets import asyncio -from queue import Queue # import threading import argparse import json @@ -30,12 +29,13 @@ parser.add_argument("--audio_in", args = parser.parse_args() +# voices = asyncio.Queue() +from queue import Queue voices = Queue() - - # 其他函数可以通过调用send(data)来发送数据,例如: -async def record(): +async def record_microphone(): + import pyaudio #print("2") global voices FORMAT = pyaudio.paInt16 @@ -59,8 +59,32 @@ async def record(): #print(voices.qsize()) await asyncio.sleep(0.01) - +# 其他函数可以通过调用send(data)来发送数据,例如: +async def record_from_scp(): + global voices + if args.audio_in.endswith(".scp"): + f_scp = open(args.audio_in) + wavs = f_scp.readlines() + else: + wavs = [args.audio_in] + for wav in wavs: + wav_splits = wav.strip().split() + wav_path = wav_splits[1] if len(wav_splits) > 1 else wav_splits[0] + bytes = open(wav_path, "rb") + bytes = bytes.read() + + stride = int(args.chunk_size/1000*16000*2) + chunk_num = (len(bytes)-1)//stride + 1 + for i in range(chunk_num): + beg = i*stride + data_chunk = bytes[beg:beg+stride] + voices.put(data_chunk) + # print("data_chunk: ", len(data_chunk)) + # print(voices.qsize()) + + await asyncio.sleep(args.chunk_size/1000) + async def ws_send(): global voices @@ -97,7 +121,10 @@ async def ws_client(): uri = "ws://{}:{}".format(args.host, args.port) #ws = await websockets.connect(uri, subprotocols=["binary"]) # 创建一个长连接 async for websocket in websockets.connect(uri, subprotocols=["binary"], ping_interval=None): - task = asyncio.create_task(record()) # 创建一个后台任务录音 + if args.audio_in is not None: + task = asyncio.create_task(record_from_scp()) # 创建一个后台任务录音 + else: + task = asyncio.create_task(record_microphone()) # 创建一个后台任务录音 task2 = asyncio.create_task(ws_send()) # 创建一个后台任务发送 task3 = asyncio.create_task(message()) # 创建一个后台接收消息的任务 await asyncio.gather(task, task2, task3) diff --git a/funasr/runtime/python/websocket/ASR_server_2pass.py b/funasr/runtime/python/websocket/ASR_server_2pass.py index 55dc2e299..135a3cc34 100644 --- a/funasr/runtime/python/websocket/ASR_server_2pass.py +++ b/funasr/runtime/python/websocket/ASR_server_2pass.py @@ -105,7 +105,7 @@ else: inference_pipeline_asr_online = pipeline( task=Tasks.auto_speech_recognition, model='damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online', - model_revision='v1.0.2') + model_revision=None) print("model loaded") From 7584bbd6f3e321cc8bc970739a7cfce29ffcc18b Mon Sep 17 00:00:00 2001 From: "haoneng.lhn" Date: Thu, 27 Apr 2023 00:21:20 +0800 Subject: [PATCH 05/15] update paraformer streaming code --- .../bin/asr_inference_paraformer_streaming.py | 393 +++++------------- funasr/models/e2e_asr_paraformer.py | 4 +- funasr/models/encoder/sanm_encoder.py | 21 +- funasr/models/predictor/cif.py | 128 +++--- funasr/modules/embedding.py | 13 +- 5 files changed, 196 insertions(+), 363 deletions(-) diff --git a/funasr/bin/asr_inference_paraformer_streaming.py b/funasr/bin/asr_inference_paraformer_streaming.py index 821f69429..939ffe99f 100644 --- a/funasr/bin/asr_inference_paraformer_streaming.py +++ b/funasr/bin/asr_inference_paraformer_streaming.py @@ -19,7 +19,6 @@ from typing import List import numpy as np import torch -import torchaudio from typeguard import check_argument_types from funasr.fileio.datadir_writer import DatadirWriter @@ -40,11 +39,12 @@ from funasr.utils.types import str2bool from funasr.utils.types import str2triple_str from funasr.utils.types import str_or_none from funasr.utils import asr_utils, wav_utils, postprocess_utils -from funasr.models.frontend.wav_frontend import WavFrontend -from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer +from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export + np.set_printoptions(threshold=np.inf) + class Speech2Text: """Speech2Text class @@ -89,7 +89,7 @@ class Speech2Text: ) frontend = None if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None: - frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf) + frontend = WavFrontendOnline(cmvn_file=cmvn_file, **asr_train_args.frontend_conf) logging.info("asr_model: {}".format(asr_model)) logging.info("asr_train_args: {}".format(asr_train_args)) @@ -189,8 +189,7 @@ class Speech2Text: @torch.no_grad() def __call__( - self, cache: dict, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None, - begin_time: int = 0, end_time: int = None, + self, cache: dict, speech: Union[torch.Tensor], speech_lengths: Union[torch.Tensor] = None ): """Inference @@ -201,38 +200,57 @@ class Speech2Text: """ assert check_argument_types() - - # Input as audio signal - if isinstance(speech, np.ndarray): - speech = torch.tensor(speech) - if self.frontend is not None: - feats, feats_len = self.frontend.forward(speech, speech_lengths) - feats = to_device(feats, device=self.device) - feats_len = feats_len.int() - self.asr_model.frontend = None + results = [] + cache_en = cache["encoder"] + if speech.shape[1] < 16 * 60 and cache["is_final"]: + cache["last_chunk"] = True + feats = cache["feats"] + feats_len = torch.tensor([feats.shape[1]]) else: - feats = speech - feats_len = speech_lengths - lfr_factor = max(1, (feats.size()[-1] // 80) - 1) - feats_len = cache["encoder"]["stride"] + cache["encoder"]["pad_left"] + cache["encoder"]["pad_right"] - feats = feats[:,cache["encoder"]["start_idx"]:cache["encoder"]["start_idx"]+feats_len,:] - feats_len = torch.tensor([feats_len]) - batch = {"speech": feats, "speech_lengths": feats_len, "cache": cache} + if self.frontend is not None: + feats, feats_len = self.frontend.forward(speech, speech_lengths, cache_en["is_final"]) + feats = to_device(feats, device=self.device) + feats_len = feats_len.int() + self.asr_model.frontend = None + else: + feats = speech + feats_len = speech_lengths - # a. To device + if feats.shape[1] != 0: + if cache_en["is_final"]: + if feats.shape[1] + cache_en["chunk_size"][2] < cache_en["chunk_size"][1]: + cache_en["last_chunk"] = True + else: + # first chunk + feats_chunk1 = feats[:, :cache_en["chunk_size"][1], :] + feats_len = torch.tensor([feats_chunk1.shape[1]]) + results_chunk1 = self.infer(feats_chunk1, feats_len, cache) + + # last chunk + cache_en["last_chunk"] = True + feats_chunk2 = feats[:, -(feats.shape[1] + cache_en["chunk_size"][2] - cache_en["chunk_size"][1]):, :] + feats_len = torch.tensor([feats_chunk2.shape[1]]) + results_chunk2 = self.infer(feats_chunk2, feats_len, cache) + + return results_chunk1 + results_chunk2 + + results = self.infer(feats, feats_len, cache) + + return results + + @torch.no_grad() + def infer(self, feats: Union[torch.Tensor], feats_len: Union[torch.Tensor], cache: List = None): + batch = {"speech": feats, "speech_lengths": feats_len} batch = to_device(batch, device=self.device) - # b. Forward Encoder - enc, enc_len = self.asr_model.encode_chunk(feats, feats_len, cache) + enc, enc_len = self.asr_model.encode_chunk(feats, feats_len, cache=cache) if isinstance(enc, tuple): enc = enc[0] # assert len(enc) == 1, len(enc) enc_len_batch_total = torch.sum(enc_len).item() * self.encoder_downsampling_factor predictor_outs = self.asr_model.calc_predictor_chunk(enc, cache) - pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \ - predictor_outs[2], predictor_outs[3] - pre_token_length = pre_token_length.floor().long() + pre_acoustic_embeds, pre_token_length= predictor_outs[0], predictor_outs[1] if torch.max(pre_token_length) < 1: return [] decoder_outs = self.asr_model.cal_decoder_with_predictor_chunk(enc, pre_acoustic_embeds, cache) @@ -279,166 +297,12 @@ class Speech2Text: text = self.tokenizer.tokens2text(token) else: text = None - - results.append((text, token, token_int, hyp, enc_len_batch_total, lfr_factor)) + results.append(text) # assert check_return_type(results) return results -class Speech2TextExport: - """Speech2TextExport class - - """ - - def __init__( - self, - asr_train_config: Union[Path, str] = None, - asr_model_file: Union[Path, str] = None, - cmvn_file: Union[Path, str] = None, - lm_train_config: Union[Path, str] = None, - lm_file: Union[Path, str] = None, - token_type: str = None, - bpemodel: str = None, - device: str = "cpu", - maxlenratio: float = 0.0, - minlenratio: float = 0.0, - dtype: str = "float32", - beam_size: int = 20, - ctc_weight: float = 0.5, - lm_weight: float = 1.0, - ngram_weight: float = 0.9, - penalty: float = 0.0, - nbest: int = 1, - frontend_conf: dict = None, - hotword_list_or_file: str = None, - **kwargs, - ): - - # 1. Build ASR model - asr_model, asr_train_args = ASRTask.build_model_from_file( - asr_train_config, asr_model_file, cmvn_file, device - ) - frontend = None - if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None: - frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf) - - logging.info("asr_model: {}".format(asr_model)) - logging.info("asr_train_args: {}".format(asr_train_args)) - asr_model.to(dtype=getattr(torch, dtype)).eval() - - token_list = asr_model.token_list - - logging.info(f"Decoding device={device}, dtype={dtype}") - - # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text - if token_type is None: - token_type = asr_train_args.token_type - if bpemodel is None: - bpemodel = asr_train_args.bpemodel - - if token_type is None: - tokenizer = None - elif token_type == "bpe": - if bpemodel is not None: - tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel) - else: - tokenizer = None - else: - tokenizer = build_tokenizer(token_type=token_type) - converter = TokenIDConverter(token_list=token_list) - logging.info(f"Text tokenizer: {tokenizer}") - - # self.asr_model = asr_model - self.asr_train_args = asr_train_args - self.converter = converter - self.tokenizer = tokenizer - - self.device = device - self.dtype = dtype - self.nbest = nbest - self.frontend = frontend - - model = Paraformer_export(asr_model, onnx=False) - self.asr_model = model - - @torch.no_grad() - def __call__( - self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None - ): - """Inference - - Args: - speech: Input speech data - Returns: - text, token, token_int, hyp - - """ - assert check_argument_types() - - # Input as audio signal - if isinstance(speech, np.ndarray): - speech = torch.tensor(speech) - - if self.frontend is not None: - feats, feats_len = self.frontend.forward(speech, speech_lengths) - feats = to_device(feats, device=self.device) - feats_len = feats_len.int() - self.asr_model.frontend = None - else: - feats = speech - feats_len = speech_lengths - - enc_len_batch_total = feats_len.sum() - lfr_factor = max(1, (feats.size()[-1] // 80) - 1) - batch = {"speech": feats, "speech_lengths": feats_len} - - # a. To device - batch = to_device(batch, device=self.device) - - decoder_outs = self.asr_model(**batch) - decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1] - - results = [] - b, n, d = decoder_out.size() - for i in range(b): - am_scores = decoder_out[i, :ys_pad_lens[i], :] - - yseq = am_scores.argmax(dim=-1) - score = am_scores.max(dim=-1)[0] - score = torch.sum(score, dim=-1) - # pad with mask tokens to ensure compatibility with sos/eos tokens - yseq = torch.tensor( - yseq.tolist(), device=yseq.device - ) - nbest_hyps = [Hypothesis(yseq=yseq, score=score)] - - for hyp in nbest_hyps: - assert isinstance(hyp, (Hypothesis)), type(hyp) - - # remove sos/eos and get results - last_pos = -1 - if isinstance(hyp.yseq, list): - token_int = hyp.yseq[1:last_pos] - else: - token_int = hyp.yseq[1:last_pos].tolist() - - # remove blank symbol id, which is assumed to be 0 - token_int = list(filter(lambda x: x != 0 and x != 2, token_int)) - - # Change integer-ids to tokens - token = self.converter.ids2tokens(token_int) - - if self.tokenizer is not None: - text = self.tokenizer.tokens2text(token) - else: - text = None - - results.append((text, token, token_int, hyp, enc_len_batch_total, lfr_factor)) - - return results - - def inference( maxlenratio: float, minlenratio: float, @@ -536,8 +400,6 @@ def inference_modelscope( **kwargs, ): assert check_argument_types() - ncpu = kwargs.get("ncpu", 1) - torch.set_num_threads(ncpu) if word_lm_train_config is not None: raise NotImplementedError("Word LM is not implemented") @@ -580,11 +442,9 @@ def inference_modelscope( penalty=penalty, nbest=nbest, ) - if export_mode: - speech2text = Speech2TextExport(**speech2text_kwargs) - else: - speech2text = Speech2Text(**speech2text_kwargs) - + + speech2text = Speech2Text(**speech2text_kwargs) + def _load_bytes(input): middle_data = np.frombuffer(input, dtype=np.int16) middle_data = np.asarray(middle_data) @@ -599,7 +459,33 @@ def inference_modelscope( offset = i.min + abs_max array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32) return array - + + def _prepare_cache(cache: dict = {}, chunk_size=[5,10,5], batch_size=1): + if len(cache) > 0: + return cache + + cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, 320)), + "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size, "last_chunk": False, + "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], 560))} + cache["encoder"] = cache_en + + cache_de = {"decode_fsmn": None} + cache["decoder"] = cache_de + + return cache + + def _cache_reset(cache: dict = {}, chunk_size=[5,10,5], batch_size=1): + if len(cache) > 0: + cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, 320)), + "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size, "last_chunk": False, + "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], 560))} + cache["encoder"] = cache_en + + cache_de = {"decode_fsmn": None} + cache["decoder"] = cache_de + + return cache + def _forward( data_path_and_name_and_type, raw_inputs: Union[np.ndarray, torch.Tensor] = None, @@ -610,123 +496,35 @@ def inference_modelscope( ): # 3. Build data-iterator + if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "bytes": + raw_inputs = _load_bytes(data_path_and_name_and_type[0]) + raw_inputs = torch.tensor(raw_inputs) + if data_path_and_name_and_type is None and raw_inputs is not None: + if isinstance(raw_inputs, np.ndarray): + raw_inputs = torch.tensor(raw_inputs) is_final = False cache = {} + chunk_size = [5, 10, 5] if param_dict is not None and "cache" in param_dict: cache = param_dict["cache"] if param_dict is not None and "is_final" in param_dict: is_final = param_dict["is_final"] + if param_dict is not None and "chunk_size" in param_dict: + chunk_size = param_dict["chunk_size"] - if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "bytes": - raw_inputs = _load_bytes(data_path_and_name_and_type[0]) - raw_inputs = torch.tensor(raw_inputs) - if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound": - raw_inputs = torchaudio.load(data_path_and_name_and_type[0])[0][0] - is_final = True - if data_path_and_name_and_type is None and raw_inputs is not None: - if isinstance(raw_inputs, np.ndarray): - raw_inputs = torch.tensor(raw_inputs) # 7 .Start for-loop # FIXME(kamo): The output format should be discussed about + raw_inputs = torch.unsqueeze(raw_inputs, axis=0) + input_lens = torch.tensor([raw_inputs.shape[1]]) asr_result_list = [] - results = [] - asr_result = "" - wait = True - if len(cache) == 0: - cache["encoder"] = {"start_idx": 0, "pad_left": 0, "stride": 10, "pad_right": 5, "cif_hidden": None, "cif_alphas": None, "is_final": is_final, "left": 0, "right": 0} - cache_de = {"decode_fsmn": None} - cache["decoder"] = cache_de - cache["first_chunk"] = True - cache["speech"] = [] - cache["accum_speech"] = 0 - if raw_inputs is not None: - if len(cache["speech"]) == 0: - cache["speech"] = raw_inputs - else: - cache["speech"] = torch.cat([cache["speech"], raw_inputs], dim=0) - cache["accum_speech"] += len(raw_inputs) - while cache["accum_speech"] >= 960: - if cache["first_chunk"]: - if cache["accum_speech"] >= 14400: - speech = torch.unsqueeze(cache["speech"], axis=0) - speech_length = torch.tensor([len(cache["speech"])]) - cache["encoder"]["pad_left"] = 5 - cache["encoder"]["pad_right"] = 5 - cache["encoder"]["stride"] = 10 - cache["encoder"]["left"] = 5 - cache["encoder"]["right"] = 0 - results = speech2text(cache, speech, speech_length) - cache["accum_speech"] -= 4800 - cache["first_chunk"] = False - cache["encoder"]["start_idx"] = -5 - cache["encoder"]["is_final"] = False - wait = False - else: - if is_final: - cache["encoder"]["stride"] = len(cache["speech"]) // 960 - cache["encoder"]["pad_left"] = 0 - cache["encoder"]["pad_right"] = 0 - speech = torch.unsqueeze(cache["speech"], axis=0) - speech_length = torch.tensor([len(cache["speech"])]) - results = speech2text(cache, speech, speech_length) - cache["accum_speech"] = 0 - wait = False - else: - break - else: - if cache["accum_speech"] >= 19200: - cache["encoder"]["start_idx"] += 10 - cache["encoder"]["stride"] = 10 - cache["encoder"]["pad_left"] = 5 - cache["encoder"]["pad_right"] = 5 - cache["encoder"]["left"] = 0 - cache["encoder"]["right"] = 0 - speech = torch.unsqueeze(cache["speech"], axis=0) - speech_length = torch.tensor([len(cache["speech"])]) - results = speech2text(cache, speech, speech_length) - cache["accum_speech"] -= 9600 - wait = False - else: - if is_final: - cache["encoder"]["is_final"] = True - if cache["accum_speech"] >= 14400: - cache["encoder"]["start_idx"] += 10 - cache["encoder"]["stride"] = 10 - cache["encoder"]["pad_left"] = 5 - cache["encoder"]["pad_right"] = 5 - cache["encoder"]["left"] = 0 - cache["encoder"]["right"] = cache["accum_speech"] // 960 - 15 - speech = torch.unsqueeze(cache["speech"], axis=0) - speech_length = torch.tensor([len(cache["speech"])]) - results = speech2text(cache, speech, speech_length) - cache["accum_speech"] -= 9600 - wait = False - else: - cache["encoder"]["start_idx"] += 10 - cache["encoder"]["stride"] = cache["accum_speech"] // 960 - 5 - cache["encoder"]["pad_left"] = 5 - cache["encoder"]["pad_right"] = 0 - cache["encoder"]["left"] = 0 - cache["encoder"]["right"] = 0 - speech = torch.unsqueeze(cache["speech"], axis=0) - speech_length = torch.tensor([len(cache["speech"])]) - results = speech2text(cache, speech, speech_length) - cache["accum_speech"] = 0 - wait = False - else: - break - - if len(results) >= 1: - asr_result += results[0][0] - if asr_result == "": - asr_result = "sil" - if wait: - asr_result = "waiting_for_more_voice" - item = {'key': "utt", 'value': asr_result} - asr_result_list.append(item) - else: - return [] + cache = _prepare_cache(cache, chunk_size=chunk_size, batch_size=1) + cache["encoder"]["is_final"] = is_final + asr_result = speech2text(cache, raw_inputs, input_lens) + item = {'key': "utt", 'value': asr_result} + asr_result_list.append(item) + if is_final: + cache = _cache_reset(cache, chunk_size=chunk_size, batch_size=1) return asr_result_list return _forward @@ -921,4 +719,3 @@ if __name__ == "__main__": # rec_result = inference_16k_pipline(audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav') # print(rec_result) - diff --git a/funasr/models/e2e_asr_paraformer.py b/funasr/models/e2e_asr_paraformer.py index 699d85fdb..d02783f49 100644 --- a/funasr/models/e2e_asr_paraformer.py +++ b/funasr/models/e2e_asr_paraformer.py @@ -712,9 +712,9 @@ class ParaformerOnline(Paraformer): def calc_predictor_chunk(self, encoder_out, cache=None): - pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = \ + pre_acoustic_embeds, pre_token_length = \ self.predictor.forward_chunk(encoder_out, cache["encoder"]) - return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index + return pre_acoustic_embeds, pre_token_length def cal_decoder_with_predictor_chunk(self, encoder_out, sematic_embeds, cache=None): decoder_outs = self.decoder.forward_chunk( diff --git a/funasr/models/encoder/sanm_encoder.py b/funasr/models/encoder/sanm_encoder.py index f2502bbb6..7d84ad5f1 100644 --- a/funasr/models/encoder/sanm_encoder.py +++ b/funasr/models/encoder/sanm_encoder.py @@ -6,9 +6,11 @@ from typing import Union import logging import torch import torch.nn as nn +import torch.nn.functional as F from funasr.modules.streaming_utils.chunk_utilis import overlap_chunk from typeguard import check_argument_types import numpy as np +from funasr.torch_utils.device_funcs import to_device from funasr.modules.nets_utils import make_pad_mask from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM, MultiHeadedAttentionSANMwithMask from funasr.modules.embedding import SinusoidalPositionEncoder, StreamSinusoidalPositionEncoder @@ -349,6 +351,23 @@ class SANMEncoder(AbsEncoder): return (xs_pad, intermediate_outs), olens, None return xs_pad, olens, None + def _add_overlap_chunk(self, feats: np.ndarray, cache: dict = {}): + if len(cache) == 0: + return feats + # process last chunk + cache["feats"] = to_device(cache["feats"], device=feats.device) + overlap_feats = torch.cat((cache["feats"], feats), dim=1) + if cache["is_final"]: + cache["feats"] = overlap_feats[:, -cache["chunk_size"][0]:, :] + if not cache["last_chunk"]: + padding_length = sum(cache["chunk_size"]) - overlap_feats.shape[1] + overlap_feats = overlap_feats.transpose(1, 2) + overlap_feats = F.pad(overlap_feats, (0, padding_length)) + overlap_feats = overlap_feats.transpose(1, 2) + else: + cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :] + return overlap_feats + def forward_chunk(self, xs_pad: torch.Tensor, ilens: torch.Tensor, @@ -360,7 +379,7 @@ class SANMEncoder(AbsEncoder): xs_pad = xs_pad else: xs_pad = self.embed(xs_pad, cache) - + xs_pad = self._add_overlap_chunk(xs_pad, cache) encoder_outs = self.encoders0(xs_pad, None, None, None, None) xs_pad, masks = encoder_outs[0], encoder_outs[1] intermediate_outs = [] diff --git a/funasr/models/predictor/cif.py b/funasr/models/predictor/cif.py index a5273f841..c59e24502 100644 --- a/funasr/models/predictor/cif.py +++ b/funasr/models/predictor/cif.py @@ -2,6 +2,7 @@ import torch from torch import nn import logging import numpy as np +from funasr.torch_utils.device_funcs import to_device from funasr.modules.nets_utils import make_pad_mask from funasr.modules.streaming_utils.utils import sequence_mask @@ -200,7 +201,7 @@ class CifPredictorV2(nn.Module): return acoustic_embeds, token_num, alphas, cif_peak def forward_chunk(self, hidden, cache=None): - b, t, d = hidden.size() + batch_size, len_time, hidden_size = hidden.shape h = hidden context = h.transpose(1, 2) queries = self.pad(context) @@ -211,58 +212,81 @@ class CifPredictorV2(nn.Module): alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold) alphas = alphas.squeeze(-1) - mask_chunk_predictor = None - if cache is not None: - mask_chunk_predictor = None - mask_chunk_predictor = torch.zeros_like(alphas) - mask_chunk_predictor[:, cache["pad_left"]:cache["stride"] + cache["pad_left"]] = 1.0 - - if mask_chunk_predictor is not None: - alphas = alphas * mask_chunk_predictor - - if cache is not None: - if cache["is_final"]: - alphas[:, cache["stride"] + cache["pad_left"] - 1] += 0.45 - if cache["cif_hidden"] is not None: - hidden = torch.cat((cache["cif_hidden"], hidden), 1) - if cache["cif_alphas"] is not None: - alphas = torch.cat((cache["cif_alphas"], alphas), -1) - token_num = alphas.sum(-1) - acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold) - len_time = alphas.size(-1) - last_fire_place = len_time - 1 - last_fire_remainds = 0.0 - pre_alphas_length = 0 - last_fire = False - - mask_chunk_peak_predictor = None - if cache is not None: - mask_chunk_peak_predictor = None - mask_chunk_peak_predictor = torch.zeros_like(cif_peak) - if cache["cif_alphas"] is not None: - pre_alphas_length = cache["cif_alphas"].size(-1) - mask_chunk_peak_predictor[:, :pre_alphas_length] = 1.0 - mask_chunk_peak_predictor[:, pre_alphas_length + cache["pad_left"]:pre_alphas_length + cache["stride"] + cache["pad_left"]] = 1.0 - - if mask_chunk_peak_predictor is not None: - cif_peak = cif_peak * mask_chunk_peak_predictor.squeeze(-1) - - for i in range(len_time): - if cif_peak[0][len_time - 1 - i] > self.threshold or cif_peak[0][len_time - 1 - i] == self.threshold: - last_fire_place = len_time - 1 - i - last_fire_remainds = cif_peak[0][len_time - 1 - i] - self.threshold - last_fire = True - break - if last_fire: - last_fire_remainds = torch.tensor([last_fire_remainds], dtype=alphas.dtype).to(alphas.device) - cache["cif_hidden"] = hidden[:, last_fire_place:, :] - cache["cif_alphas"] = torch.cat((last_fire_remainds.unsqueeze(0), alphas[:, last_fire_place+1:]), -1) - else: - cache["cif_hidden"] = hidden - cache["cif_alphas"] = alphas - token_num_int = token_num.floor().type(torch.int32).item() - return acoustic_embeds[:, 0:token_num_int, :], token_num, alphas, cif_peak + token_length = [] + list_fires = [] + list_frames = [] + cache_alphas = [] + cache_hiddens = [] + + if cache is not None and "chunk_size" in cache: + alphas[:, :cache["chunk_size"][0]] = 0.0 + alphas[:, sum(cache["chunk_size"][:2]):] = 0.0 + if cache is not None and "cif_alphas" in cache and "cif_hidden" in cache: + cache["cif_hidden"] = to_device(cache["cif_hidden"], device=hidden.device) + cache["cif_alphas"] = to_device(cache["cif_alphas"], device=alphas.device) + hidden = torch.cat((cache["cif_hidden"], hidden), dim=1) + alphas = torch.cat((cache["cif_alphas"], alphas), dim=1) + if cache is not None and "last_chunk" in cache and cache["last_chunk"]: + tail_hidden = torch.zeros((batch_size, 1, hidden_size), device=hidden.device) + tail_alphas = torch.tensor([[self.tail_threshold]], device=alphas.device) + tail_alphas = torch.tile(tail_alphas, (batch_size, 1)) + hidden = torch.cat((hidden, tail_hidden), dim=1) + alphas = torch.cat((alphas, tail_alphas), dim=1) + + len_time = alphas.shape[1] + for b in range(batch_size): + integrate = 0.0 + frames = torch.zeros((hidden_size), device=hidden.device) + list_frame = [] + list_fire = [] + for t in range(len_time): + alpha = alphas[b][t] + if alpha + integrate < self.threshold: + integrate += alpha + list_fire.append(integrate) + frames += alpha * hidden[b][t] + else: + frames += (self.threshold - integrate) * hidden[b][t] + list_frame.append(frames) + integrate += alpha + list_fire.append(integrate) + integrate -= self.threshold + frames = integrate * hidden[b][t] + + cache_alphas.append(integrate) + if integrate > 0.0: + cache_hiddens.append(frames / integrate) + else: + cache_hiddens.append(frames) + + token_length.append(torch.tensor(len(list_frame), device=alphas.device)) + list_fires.append(list_fire) + list_frames.append(list_frame) + + cache["cif_alphas"] = torch.stack(cache_alphas, axis=0) + cache["cif_alphas"] = torch.unsqueeze(cache["cif_alphas"], axis=0) + cache["cif_hidden"] = torch.stack(cache_hiddens, axis=0) + cache["cif_hidden"] = torch.unsqueeze(cache["cif_hidden"], axis=0) + + max_token_len = max(token_length) + if max_token_len == 0: + return hidden, torch.stack(token_length, 0) + list_ls = [] + for b in range(batch_size): + pad_frames = torch.zeros((max_token_len - token_length[b], hidden_size), device=alphas.device) + if token_length[b] == 0: + list_ls.append(pad_frames) + else: + list_frames[b] = torch.stack(list_frames[b]) + list_ls.append(torch.cat((list_frames[b], pad_frames), dim=0)) + + cache["cif_alphas"] = torch.stack(cache_alphas, axis=0) + cache["cif_alphas"] = torch.unsqueeze(cache["cif_alphas"], axis=0) + cache["cif_hidden"] = torch.stack(cache_hiddens, axis=0) + cache["cif_hidden"] = torch.unsqueeze(cache["cif_hidden"], axis=0) + return torch.stack(list_ls, 0), torch.stack(token_length, 0) + def tail_process_fn(self, hidden, alphas, token_num=None, mask=None): b, t, d = hidden.size() diff --git a/funasr/modules/embedding.py b/funasr/modules/embedding.py index c347e24f1..aaac80a7d 100644 --- a/funasr/modules/embedding.py +++ b/funasr/modules/embedding.py @@ -425,21 +425,14 @@ class StreamSinusoidalPositionEncoder(torch.nn.Module): return encoding.type(dtype) def forward(self, x, cache=None): - start_idx = 0 - pad_left = 0 - pad_right = 0 batch_size, timesteps, input_dim = x.size() + start_idx = 0 if cache is not None: start_idx = cache["start_idx"] - pad_left = cache["left"] - pad_right = cache["right"] + cache["start_idx"] += timesteps positions = torch.arange(1, timesteps+start_idx+1)[None, :] position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device) - outputs = x + position_encoding[:, start_idx: start_idx + timesteps] - outputs = outputs.transpose(1, 2) - outputs = F.pad(outputs, (pad_left, pad_right)) - outputs = outputs.transpose(1, 2) - return outputs + return x + position_encoding[:, start_idx: start_idx + timesteps] class StreamingRelPositionalEncoding(torch.nn.Module): """Relative positional encoding. From 6abf68388ce81bb8aa8dfa29f5f68b10d88cae05 Mon Sep 17 00:00:00 2001 From: hnluo Date: Thu, 27 Apr 2023 01:15:59 +0800 Subject: [PATCH 06/15] Update asr_inference_paraformer_streaming.py --- funasr/bin/asr_inference_paraformer_streaming.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/funasr/bin/asr_inference_paraformer_streaming.py b/funasr/bin/asr_inference_paraformer_streaming.py index 939ffe99f..63ef3a34a 100644 --- a/funasr/bin/asr_inference_paraformer_streaming.py +++ b/funasr/bin/asr_inference_paraformer_streaming.py @@ -202,9 +202,9 @@ class Speech2Text: assert check_argument_types() results = [] cache_en = cache["encoder"] - if speech.shape[1] < 16 * 60 and cache["is_final"]: - cache["last_chunk"] = True - feats = cache["feats"] + if speech.shape[1] < 16 * 60 and cache_en["is_final"]: + cache_en["last_chunk"] = True + feats = cache_en["feats"] feats_len = torch.tensor([feats.shape[1]]) else: if self.frontend is not None: From c5992ca03ea6d6c7b78e5c1d481a612d0f91ac21 Mon Sep 17 00:00:00 2001 From: hnluo Date: Thu, 27 Apr 2023 01:47:46 +0800 Subject: [PATCH 07/15] Update asr_inference_paraformer_streaming.py --- funasr/bin/asr_inference_paraformer_streaming.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/funasr/bin/asr_inference_paraformer_streaming.py b/funasr/bin/asr_inference_paraformer_streaming.py index 63ef3a34a..3f13982f1 100644 --- a/funasr/bin/asr_inference_paraformer_streaming.py +++ b/funasr/bin/asr_inference_paraformer_streaming.py @@ -203,7 +203,7 @@ class Speech2Text: results = [] cache_en = cache["encoder"] if speech.shape[1] < 16 * 60 and cache_en["is_final"]: - cache_en["last_chunk"] = True + cache_en["tail_chunk"] = True feats = cache_en["feats"] feats_len = torch.tensor([feats.shape[1]]) else: @@ -232,7 +232,7 @@ class Speech2Text: feats_len = torch.tensor([feats_chunk2.shape[1]]) results_chunk2 = self.infer(feats_chunk2, feats_len, cache) - return results_chunk1 + results_chunk2 + return ["".join(results_chunk1 + results_chunk2)] results = self.infer(feats, feats_len, cache) @@ -466,7 +466,7 @@ def inference_modelscope( cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, 320)), "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size, "last_chunk": False, - "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], 560))} + "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], 560)), "tail_chunk": False} cache["encoder"] = cache_en cache_de = {"decode_fsmn": None} @@ -478,7 +478,7 @@ def inference_modelscope( if len(cache) > 0: cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, 320)), "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size, "last_chunk": False, - "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], 560))} + "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], 560)), "tail_chunk": False} cache["encoder"] = cache_en cache_de = {"decode_fsmn": None} From 9ff5b683db460a428d1804f64c25071ae2172ea7 Mon Sep 17 00:00:00 2001 From: hnluo Date: Thu, 27 Apr 2023 01:49:01 +0800 Subject: [PATCH 08/15] Update sanm_encoder.py --- funasr/models/encoder/sanm_encoder.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/funasr/models/encoder/sanm_encoder.py b/funasr/models/encoder/sanm_encoder.py index 7d84ad5f1..969ddadf2 100644 --- a/funasr/models/encoder/sanm_encoder.py +++ b/funasr/models/encoder/sanm_encoder.py @@ -379,7 +379,10 @@ class SANMEncoder(AbsEncoder): xs_pad = xs_pad else: xs_pad = self.embed(xs_pad, cache) - xs_pad = self._add_overlap_chunk(xs_pad, cache) + if cache["tail_chunk"]: + xs_pad = cache["feats"] + else: + xs_pad = self._add_overlap_chunk(xs_pad, cache) encoder_outs = self.encoders0(xs_pad, None, None, None, None) xs_pad, masks = encoder_outs[0], encoder_outs[1] intermediate_outs = [] From 493dda8f98662a546e908c08039898d0351782fd Mon Sep 17 00:00:00 2001 From: hnluo Date: Thu, 27 Apr 2023 01:57:42 +0800 Subject: [PATCH 09/15] Update asr_inference_paraformer_streaming.py --- funasr/bin/asr_inference_paraformer_streaming.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/funasr/bin/asr_inference_paraformer_streaming.py b/funasr/bin/asr_inference_paraformer_streaming.py index 3f13982f1..c70baf023 100644 --- a/funasr/bin/asr_inference_paraformer_streaming.py +++ b/funasr/bin/asr_inference_paraformer_streaming.py @@ -206,6 +206,8 @@ class Speech2Text: cache_en["tail_chunk"] = True feats = cache_en["feats"] feats_len = torch.tensor([feats.shape[1]]) + results = self.infer(feats, feats_len, cache) + return results else: if self.frontend is not None: feats, feats_len = self.frontend.forward(speech, speech_lengths, cache_en["is_final"]) From 7e0652f8d5701e5952a1c81770de4e06e0019f9b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Thu, 27 Apr 2023 10:30:13 +0800 Subject: [PATCH 10/15] websocket --- funasr/runtime/python/websocket/ASR_client.py | 38 ++- .../python/websocket/ASR_server_streaming.py | 261 ++++++++++++++++++ .../websocket/ASR_server_streaming_asr.py | 149 ++++++++++ 3 files changed, 438 insertions(+), 10 deletions(-) create mode 100644 funasr/runtime/python/websocket/ASR_server_streaming.py create mode 100644 funasr/runtime/python/websocket/ASR_server_streaming_asr.py diff --git a/funasr/runtime/python/websocket/ASR_client.py b/funasr/runtime/python/websocket/ASR_client.py index cc0e7b6e4..b0abfc793 100644 --- a/funasr/runtime/python/websocket/ASR_client.py +++ b/funasr/runtime/python/websocket/ASR_client.py @@ -1,5 +1,4 @@ - -# import websocket #区别服务端这里是 websocket-client库 +# -*- encoding: utf-8 -*- import time import websockets import asyncio @@ -50,18 +49,21 @@ async def record_microphone(): rate=RATE, input=True, frames_per_buffer=CHUNK) - + is_speaking = True while True: data = stream.read(CHUNK) + data = data.decode('ISO-8859-1') + message = json.dumps({"chunk": args.chunk_size, "is_speaking": is_speaking, "audio": data}) - voices.put(data) + voices.put(message) #print(voices.qsize()) await asyncio.sleep(0.01) # 其他函数可以通过调用send(data)来发送数据,例如: async def record_from_scp(): + import wave global voices if args.audio_in.endswith(".scp"): f_scp = open(args.audio_in) @@ -71,15 +73,31 @@ async def record_from_scp(): for wav in wavs: wav_splits = wav.strip().split() wav_path = wav_splits[1] if len(wav_splits) > 1 else wav_splits[0] - bytes = open(wav_path, "rb") - bytes = bytes.read() - + # bytes_f = open(wav_path, "rb") + # bytes_data = bytes_f.read() + with wave.open(wav_path, "rb") as wav_file: + # 获取音频参数 + params = wav_file.getparams() + # 获取头信息的长度 + # header_length = wav_file.getheaders()[0][1] + # 读取音频帧数据,跳过头信息 + # wav_file.setpos(header_length) + frames = wav_file.readframes(wav_file.getnframes()) + + # 将音频帧数据转换为字节类型的数据 + audio_bytes = bytes(frames) stride = int(args.chunk_size/1000*16000*2) - chunk_num = (len(bytes)-1)//stride + 1 + chunk_num = (len(audio_bytes)-1)//stride + 1 + print(stride) + is_speaking = True for i in range(chunk_num): + if i == chunk_num-1: + is_speaking = False beg = i*stride - data_chunk = bytes[beg:beg+stride] - voices.put(data_chunk) + data = audio_bytes[beg:beg+stride] + data = data.decode('ISO-8859-1') + message = json.dumps({"chunk": args.chunk_size, "is_speaking": is_speaking, "audio": data}) + voices.put(message) # print("data_chunk: ", len(data_chunk)) # print(voices.qsize()) diff --git a/funasr/runtime/python/websocket/ASR_server_streaming.py b/funasr/runtime/python/websocket/ASR_server_streaming.py new file mode 100644 index 000000000..b7c54f78c --- /dev/null +++ b/funasr/runtime/python/websocket/ASR_server_streaming.py @@ -0,0 +1,261 @@ +import asyncio +import json +import websockets +import time +from queue import Queue +import threading +import argparse + +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger +import logging +import tracemalloc +import numpy as np + +tracemalloc.start() + +logger = get_logger(log_level=logging.CRITICAL) +logger.setLevel(logging.CRITICAL) + + +websocket_users = set() #维护客户端列表 + +parser = argparse.ArgumentParser() +parser.add_argument("--host", + type=str, + default="0.0.0.0", + required=False, + help="host ip, localhost, 0.0.0.0") +parser.add_argument("--port", + type=int, + default=10095, + required=False, + help="grpc server port") +parser.add_argument("--asr_model", + type=str, + default="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", + help="model from modelscope") +parser.add_argument("--vad_model", + type=str, + default="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", + help="model from modelscope") + +parser.add_argument("--punc_model", + type=str, + default="damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727", + help="model from modelscope") +parser.add_argument("--ngpu", + type=int, + default=1, + help="0 for cpu, 1 for gpu") + +args = parser.parse_args() + +print("model loading") + +def load_bytes(input): + middle_data = np.frombuffer(input, dtype=np.int16) + middle_data = np.asarray(middle_data) + if middle_data.dtype.kind not in 'iu': + raise TypeError("'middle_data' must be an array of integers") + dtype = np.dtype('float32') + if dtype.kind != 'f': + raise TypeError("'dtype' must be a floating point type") + + i = np.iinfo(middle_data.dtype) + abs_max = 2 ** (i.bits - 1) + offset = i.min + abs_max + array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32) + return array + +# vad +inference_pipeline_vad = pipeline( + task=Tasks.voice_activity_detection, + model=args.vad_model, + model_revision=None, + output_dir=None, + batch_size=1, + mode='online', + ngpu=args.ngpu, +) +# param_dict_vad = {'in_cache': dict(), "is_final": False} + +# # asr +# param_dict_asr = {} +# # param_dict["hotword"] = "小五 小五月" # 设置热词,用空格隔开 +# inference_pipeline_asr = pipeline( +# task=Tasks.auto_speech_recognition, +# model=args.asr_model, +# param_dict=param_dict_asr, +# ngpu=args.ngpu, +# ) +# if args.punc_model != "": +# # param_dict_punc = {'cache': list()} +# inference_pipeline_punc = pipeline( +# task=Tasks.punctuation, +# model=args.punc_model, +# model_revision=None, +# ngpu=args.ngpu, +# ) +# else: +# inference_pipeline_punc = None + + +inference_pipeline_asr_online = pipeline( + task=Tasks.auto_speech_recognition, + model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online', + model_revision=None) + + +print("model loaded") + + + +async def ws_serve(websocket, path): + #speek = Queue() + frames = [] # 存储所有的帧数据 + frames_online = [] # 存储所有的帧数据 + buffer = [] # 存储缓存中的帧数据(最多两个片段) + RECORD_NUM = 0 + global websocket_users + speech_start, speech_end = False, False + # 调用asr函数 + websocket.param_dict_vad = {'in_cache': dict(), "is_final": False} + websocket.param_dict_punc = {'cache': list()} + websocket.speek = Queue() #websocket 添加进队列对象 让asr读取语音数据包 + websocket.send_msg = Queue() #websocket 添加个队列对象 让ws发送消息到客户端 + websocket_users.add(websocket) + # ss = threading.Thread(target=asr, args=(websocket,)) + # ss.start() + + websocket.param_dict_asr_online = {"cache": dict(), "is_final": False} + websocket.speek_online = Queue() # websocket 添加进队列对象 让asr读取语音数据包 + ss_online = threading.Thread(target=asr_online, args=(websocket,)) + ss_online.start() + + try: + async for data in websocket: + #voices.put(message) + #print("put") + #await websocket.send("123") + + data = json.loads(data) + # message = data["data"] + message = bytes(data['audio'], 'ISO-8859-1') + chunk = data["chunk"] + chunk_num = 600//chunk + is_speaking = data["is_speaking"] + websocket.param_dict_vad["is_final"] = not is_speaking + buffer.append(message) + if len(buffer) > 2: + buffer.pop(0) # 如果缓存超过两个片段,则删除最早的一个 + + if speech_start: + # frames.append(message) + frames_online.append(message) + # RECORD_NUM += 1 + if len(frames_online) % chunk_num == 0: + audio_in = b"".join(frames_online) + websocket.speek_online.put(audio_in) + frames_online = [] + + speech_start_i, speech_end_i = vad(message, websocket) + #print(speech_start_i, speech_end_i) + if speech_start_i: + # RECORD_NUM += 1 + speech_start = speech_start_i + # frames = [] + # frames.extend(buffer) # 把之前2个语音数据快加入 + frames_online = [] + # frames_online.append(message) + frames_online.extend(buffer) + # RECORD_NUM += 1 + websocket.param_dict_asr_online["is_final"] = False + if speech_end_i: + speech_start = False + # audio_in = b"".join(frames) + # websocket.speek.put(audio_in) + # frames = [] # 清空所有的帧数据 + frames_online = [] + websocket.param_dict_asr_online["is_final"] = True + # buffer = [] # 清空缓存中的帧数据(最多两个片段) + # RECORD_NUM = 0 + if not websocket.send_msg.empty(): + await websocket.send(websocket.send_msg.get()) + websocket.send_msg.task_done() + + + except websockets.ConnectionClosed: + print("ConnectionClosed...", websocket_users) # 链接断开 + websocket_users.remove(websocket) + except websockets.InvalidState: + print("InvalidState...") # 无效状态 + except Exception as e: + print("Exception:", e) + + +# def asr(websocket): # ASR推理 +# global inference_pipeline_asr +# # global param_dict_punc +# global websocket_users +# while websocket in websocket_users: +# if not websocket.speek.empty(): +# audio_in = websocket.speek.get() +# websocket.speek.task_done() +# if len(audio_in) > 0: +# rec_result = inference_pipeline_asr(audio_in=audio_in) +# if inference_pipeline_punc is not None and 'text' in rec_result: +# rec_result = inference_pipeline_punc(text_in=rec_result['text'], param_dict=websocket.param_dict_punc) +# # print(rec_result) +# if "text" in rec_result: +# message = json.dumps({"mode": "offline", "text": rec_result["text"]}) +# websocket.send_msg.put(message) # 存入发送队列 直接调用send发送不了 +# +# time.sleep(0.1) + + +def asr_online(websocket): # ASR推理 + global inference_pipeline_asr_online + # global param_dict_punc + global websocket_users + while websocket in websocket_users: + if not websocket.speek_online.empty(): + audio_in = websocket.speek_online.get() + websocket.speek_online.task_done() + if len(audio_in) > 0: + # print(len(audio_in)) + audio_in = load_bytes(audio_in) + # print(audio_in.shape) + rec_result = inference_pipeline_asr_online(audio_in=audio_in, param_dict=websocket.param_dict_asr_online) + + # print(rec_result) + if "text" in rec_result: + if rec_result["text"] != "sil" and rec_result["text"] != "waiting_for_more_voice": + message = json.dumps({"mode": "online", "text": rec_result["text"]}) + websocket.send_msg.put(message) # 存入发送队列 直接调用send发送不了 + + time.sleep(0.1) + +def vad(data, websocket): # VAD推理 + global inference_pipeline_vad, param_dict_vad + #print(type(data)) + # print(param_dict_vad) + segments_result = inference_pipeline_vad(audio_in=data, param_dict=websocket.param_dict_vad) + # print(segments_result) + # print(param_dict_vad) + speech_start = False + speech_end = False + + if len(segments_result) == 0 or len(segments_result["text"]) > 1: + return speech_start, speech_end + if segments_result["text"][0][0] != -1: + speech_start = True + if segments_result["text"][0][1] != -1: + speech_end = True + return speech_start, speech_end + + +start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None) +asyncio.get_event_loop().run_until_complete(start_server) +asyncio.get_event_loop().run_forever() \ No newline at end of file diff --git a/funasr/runtime/python/websocket/ASR_server_streaming_asr.py b/funasr/runtime/python/websocket/ASR_server_streaming_asr.py new file mode 100644 index 000000000..396597ee8 --- /dev/null +++ b/funasr/runtime/python/websocket/ASR_server_streaming_asr.py @@ -0,0 +1,149 @@ +import asyncio +import json +import websockets +import time +from queue import Queue +import threading +import argparse + +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger +import logging +import tracemalloc +import numpy as np + +tracemalloc.start() + +logger = get_logger(log_level=logging.CRITICAL) +logger.setLevel(logging.CRITICAL) + + +websocket_users = set() #维护客户端列表 + +parser = argparse.ArgumentParser() +parser.add_argument("--host", + type=str, + default="0.0.0.0", + required=False, + help="host ip, localhost, 0.0.0.0") +parser.add_argument("--port", + type=int, + default=10095, + required=False, + help="grpc server port") +parser.add_argument("--asr_model", + type=str, + default="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", + help="model from modelscope") +parser.add_argument("--vad_model", + type=str, + default="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", + help="model from modelscope") + +parser.add_argument("--punc_model", + type=str, + default="damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727", + help="model from modelscope") +parser.add_argument("--ngpu", + type=int, + default=1, + help="0 for cpu, 1 for gpu") + +args = parser.parse_args() + +print("model loading") + +def load_bytes(input): + middle_data = np.frombuffer(input, dtype=np.int16) + middle_data = np.asarray(middle_data) + if middle_data.dtype.kind not in 'iu': + raise TypeError("'middle_data' must be an array of integers") + dtype = np.dtype('float32') + if dtype.kind != 'f': + raise TypeError("'dtype' must be a floating point type") + + i = np.iinfo(middle_data.dtype) + abs_max = 2 ** (i.bits - 1) + offset = i.min + abs_max + array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32) + return array + +inference_pipeline_asr_online = pipeline( + task=Tasks.auto_speech_recognition, + # model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online', + model='damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online', + model_revision=None) + + +print("model loaded") + + + +async def ws_serve(websocket, path): + frames_online = [] + global websocket_users + websocket.send_msg = Queue() + websocket_users.add(websocket) + websocket.param_dict_asr_online = {"cache": dict()} + websocket.speek_online = Queue() + ss_online = threading.Thread(target=asr_online, args=(websocket,)) + ss_online.start() + try: + async for message in websocket: + message = json.loads(message) + audio = bytes(message['audio'], 'ISO-8859-1') + chunk = message["chunk"] + chunk_num = 500//chunk + is_speaking = message["is_speaking"] + websocket.param_dict_asr_online["is_final"] = not is_speaking + frames_online.append(audio) + + if len(frames_online) % chunk_num == 0 or not is_speaking: + audio_in = b"".join(frames_online) + websocket.speek_online.put(audio_in) + frames_online = [] + + if not websocket.send_msg.empty(): + await websocket.send(websocket.send_msg.get()) + websocket.send_msg.task_done() + + + except websockets.ConnectionClosed: + print("ConnectionClosed...", websocket_users) # 链接断开 + websocket_users.remove(websocket) + except websockets.InvalidState: + print("InvalidState...") # 无效状态 + except Exception as e: + print("Exception:", e) + + + +def asr_online(websocket): # ASR推理 + global inference_pipeline_asr_online + global websocket_users + while websocket in websocket_users: + if not websocket.speek_online.empty(): + audio_in = websocket.speek_online.get() + websocket.speek_online.task_done() + if len(audio_in) > 0: + # print(len(audio_in)) + audio_in = load_bytes(audio_in) + # print(audio_in.shape) + print(websocket.param_dict_asr_online["is_final"]) + rec_result = inference_pipeline_asr_online(audio_in=audio_in, param_dict=websocket.param_dict_asr_online) + if websocket.param_dict_asr_online["is_final"]: + websocket.param_dict_asr_online["cache"] = dict() + + print(rec_result) + if "text" in rec_result: + if rec_result["text"] != "sil" and rec_result["text"] != "waiting_for_more_voice": + message = json.dumps({"mode": "online", "text": rec_result["text"]}) + websocket.send_msg.put(message) # 存入发送队列 直接调用send发送不了 + + time.sleep(0.005) + + +start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None) +asyncio.get_event_loop().run_until_complete(start_server) +asyncio.get_event_loop().run_forever() \ No newline at end of file From 9624eba825069e64a64fb40dc01df51063e9271f Mon Sep 17 00:00:00 2001 From: hnluo Date: Thu, 27 Apr 2023 10:46:18 +0800 Subject: [PATCH 11/15] Update asr_inference_paraformer_streaming.py --- .../bin/asr_inference_paraformer_streaming.py | 25 ++++++++++++++----- 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/funasr/bin/asr_inference_paraformer_streaming.py b/funasr/bin/asr_inference_paraformer_streaming.py index c70baf023..ff8bb8c77 100644 --- a/funasr/bin/asr_inference_paraformer_streaming.py +++ b/funasr/bin/asr_inference_paraformer_streaming.py @@ -8,6 +8,7 @@ import os import codecs import tempfile import requests +import yaml from pathlib import Path from typing import Optional from typing import Sequence @@ -462,13 +463,23 @@ def inference_modelscope( array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32) return array + def _read_yaml(yaml_path: Union[str, Path]) -> Dict: + if not Path(yaml_path).exists(): + raise FileExistsError(f'The {yaml_path} does not exist.') + + with open(str(yaml_path), 'rb') as f: + data = yaml.load(f, Loader=yaml.Loader) + return data + def _prepare_cache(cache: dict = {}, chunk_size=[5,10,5], batch_size=1): if len(cache) > 0: return cache - - cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, 320)), + config = _read_yaml(asr_train_config) + enc_output_size = config["encoder_conf"]["output_size"] + feats_dims = config["frontend_conf"]["n_mels"] * config["frontend_conf"]["lfr_m"] + cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)), "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size, "last_chunk": False, - "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], 560)), "tail_chunk": False} + "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)), "tail_chunk": False} cache["encoder"] = cache_en cache_de = {"decode_fsmn": None} @@ -478,9 +489,12 @@ def inference_modelscope( def _cache_reset(cache: dict = {}, chunk_size=[5,10,5], batch_size=1): if len(cache) > 0: - cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, 320)), + config = _read_yaml(asr_train_config) + enc_output_size = config["encoder_conf"]["output_size"] + feats_dims = config["frontend_conf"]["n_mels"] * config["frontend_conf"]["lfr_m"] + cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)), "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size, "last_chunk": False, - "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], 560)), "tail_chunk": False} + "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)), "tail_chunk": False} cache["encoder"] = cache_en cache_de = {"decode_fsmn": None} @@ -720,4 +734,3 @@ if __name__ == "__main__": # # rec_result = inference_16k_pipline(audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav') # print(rec_result) - From a917d7557dd2b1e5263eeba7e5e4d5a5fc02f69f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Thu, 27 Apr 2023 11:41:16 +0800 Subject: [PATCH 12/15] websocket --- funasr/runtime/python/websocket/ASR_client.py | 6 ++--- .../websocket/ASR_server_streaming_asr.py | 26 ++++++++++++++----- 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/funasr/runtime/python/websocket/ASR_client.py b/funasr/runtime/python/websocket/ASR_client.py index b0abfc793..9a4a14802 100644 --- a/funasr/runtime/python/websocket/ASR_client.py +++ b/funasr/runtime/python/websocket/ASR_client.py @@ -59,7 +59,7 @@ async def record_microphone(): voices.put(message) #print(voices.qsize()) - await asyncio.sleep(0.01) + await asyncio.sleep(0.005) # 其他函数可以通过调用send(data)来发送数据,例如: async def record_from_scp(): @@ -116,8 +116,8 @@ async def ws_send(): await websocket.send(data) # 通过ws对象发送数据 except Exception as e: print('Exception occurred:', e) - await asyncio.sleep(0.01) - await asyncio.sleep(0.01) + await asyncio.sleep(0.005) + await asyncio.sleep(0.005) diff --git a/funasr/runtime/python/websocket/ASR_server_streaming_asr.py b/funasr/runtime/python/websocket/ASR_server_streaming_asr.py index 396597ee8..b8b8b8d50 100644 --- a/funasr/runtime/python/websocket/ASR_server_streaming_asr.py +++ b/funasr/runtime/python/websocket/ASR_server_streaming_asr.py @@ -89,6 +89,8 @@ async def ws_serve(websocket, path): websocket.speek_online = Queue() ss_online = threading.Thread(target=asr_online, args=(websocket,)) ss_online.start() + ss_ws_send = threading.Thread(target=ws_send, args=(websocket,)) + ss_ws_send.start() try: async for message in websocket: message = json.loads(message) @@ -104,9 +106,9 @@ async def ws_serve(websocket, path): websocket.speek_online.put(audio_in) frames_online = [] - if not websocket.send_msg.empty(): - await websocket.send(websocket.send_msg.get()) - websocket.send_msg.task_done() + # if not websocket.send_msg.empty(): + # await websocket.send(websocket.send_msg.get()) + # websocket.send_msg.task_done() except websockets.ConnectionClosed: @@ -119,11 +121,20 @@ async def ws_serve(websocket, path): -def asr_online(websocket): # ASR推理 +def ws_send(websocket): # ASR推理 global inference_pipeline_asr_online global websocket_users while websocket in websocket_users: if not websocket.speek_online.empty(): + await websocket.send(websocket.send_msg.get()) + websocket.send_msg.task_done() + time.sleep(0.005) + + +def asr_online(websocket): # ASR推理 + global websocket_users + while websocket in websocket_users: + if not websocket.send_msg.empty(): audio_in = websocket.speek_online.get() websocket.speek_online.task_done() if len(audio_in) > 0: @@ -131,10 +142,11 @@ def asr_online(websocket): # ASR推理 audio_in = load_bytes(audio_in) # print(audio_in.shape) print(websocket.param_dict_asr_online["is_final"]) - rec_result = inference_pipeline_asr_online(audio_in=audio_in, param_dict=websocket.param_dict_asr_online) + rec_result = inference_pipeline_asr_online(audio_in=audio_in, + param_dict=websocket.param_dict_asr_online) if websocket.param_dict_asr_online["is_final"]: websocket.param_dict_asr_online["cache"] = dict() - + print(rec_result) if "text" in rec_result: if rec_result["text"] != "sil" and rec_result["text"] != "waiting_for_more_voice": @@ -143,7 +155,7 @@ def asr_online(websocket): # ASR推理 time.sleep(0.005) - + start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None) asyncio.get_event_loop().run_until_complete(start_server) asyncio.get_event_loop().run_forever() \ No newline at end of file From ba8d73d57db031fa7a1265d2c837ff694d5c5c93 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Thu, 27 Apr 2023 16:39:01 +0800 Subject: [PATCH 13/15] websocket --- .gitignore | 3 +- funasr/runtime/python/websocket/ASR_server.py | 187 ------------- .../python/websocket/ASR_server_2pass.py | 252 ----------------- .../python/websocket/ASR_server_streaming.py | 261 ------------------ .../websocket/ASR_server_streaming_asr.py | 161 ----------- funasr/runtime/python/websocket/parse_args.py | 35 +++ .../websocket/{ASR_client.py => ws_client.py} | 54 +++- .../python/websocket/ws_server_online.py | 108 ++++++++ 8 files changed, 187 insertions(+), 874 deletions(-) delete mode 100644 funasr/runtime/python/websocket/ASR_server.py delete mode 100644 funasr/runtime/python/websocket/ASR_server_2pass.py delete mode 100644 funasr/runtime/python/websocket/ASR_server_streaming.py delete mode 100644 funasr/runtime/python/websocket/ASR_server_streaming_asr.py create mode 100644 funasr/runtime/python/websocket/parse_args.py rename funasr/runtime/python/websocket/{ASR_client.py => ws_client.py} (73%) create mode 100644 funasr/runtime/python/websocket/ws_server_online.py diff --git a/.gitignore b/.gitignore index 33b8c3979..b0fa5430b 100644 --- a/.gitignore +++ b/.gitignore @@ -16,4 +16,5 @@ MaaS-lib .egg* dist build -funasr.egg-info \ No newline at end of file +funasr.egg-info +sherpa \ No newline at end of file diff --git a/funasr/runtime/python/websocket/ASR_server.py b/funasr/runtime/python/websocket/ASR_server.py deleted file mode 100644 index c717e7126..000000000 --- a/funasr/runtime/python/websocket/ASR_server.py +++ /dev/null @@ -1,187 +0,0 @@ -import asyncio -import websockets -import time -from queue import Queue -import threading -import argparse -import json - -from modelscope.pipelines import pipeline -from modelscope.utils.constant import Tasks -from modelscope.utils.logger import get_logger -import logging -import tracemalloc -tracemalloc.start() - -logger = get_logger(log_level=logging.CRITICAL) -logger.setLevel(logging.CRITICAL) - - -websocket_users = set() #维护客户端列表 - -parser = argparse.ArgumentParser() -parser.add_argument("--host", - type=str, - default="0.0.0.0", - required=False, - help="host ip, localhost, 0.0.0.0") -parser.add_argument("--port", - type=int, - default=10095, - required=False, - help="grpc server port") -parser.add_argument("--asr_model", - type=str, - default="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", - help="model from modelscope") -parser.add_argument("--vad_model", - type=str, - default="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", - help="model from modelscope") - -parser.add_argument("--punc_model", - type=str, - default="", - help="model from modelscope") -parser.add_argument("--ngpu", - type=int, - default=1, - help="0 for cpu, 1 for gpu") - -args = parser.parse_args() - -print("model loading") - - -# vad -inference_pipeline_vad = pipeline( - task=Tasks.voice_activity_detection, - model=args.vad_model, - model_revision=None, - output_dir=None, - batch_size=1, - mode='online', - ngpu=args.ngpu, -) -# param_dict_vad = {'in_cache': dict(), "is_final": False} - -# asr -param_dict_asr = {} -# param_dict["hotword"] = "小五 小五月" # 设置热词,用空格隔开 -inference_pipeline_asr = pipeline( - task=Tasks.auto_speech_recognition, - model=args.asr_model, - param_dict=param_dict_asr, - ngpu=args.ngpu, -) -if args.punc_model != "": - # param_dict_punc = {'cache': list()} - inference_pipeline_punc = pipeline( - task=Tasks.punctuation, - model=args.punc_model, - model_revision=None, - ngpu=args.ngpu, - ) -else: - inference_pipeline_punc = None - -print("model loaded") - - - -async def ws_serve(websocket, path): - #speek = Queue() - frames = [] # 存储所有的帧数据 - buffer = [] # 存储缓存中的帧数据(最多两个片段) - RECORD_NUM = 0 - global websocket_users - speech_start, speech_end = False, False - # 调用asr函数 - websocket.param_dict_vad = {'in_cache': dict(), "is_final": False} - websocket.param_dict_punc = {'cache': list()} - websocket.speek = Queue() #websocket 添加进队列对象 让asr读取语音数据包 - websocket.send_msg = Queue() #websocket 添加个队列对象 让ws发送消息到客户端 - websocket_users.add(websocket) - ss = threading.Thread(target=asr, args=(websocket,)) - ss.start() - - try: - async for message in websocket: - #voices.put(message) - #print("put") - #await websocket.send("123") - buffer.append(message) - if len(buffer) > 2: - buffer.pop(0) # 如果缓存超过两个片段,则删除最早的一个 - - if speech_start: - frames.append(message) - RECORD_NUM += 1 - speech_start_i, speech_end_i = vad(message, websocket) - #print(speech_start_i, speech_end_i) - if speech_start_i: - speech_start = speech_start_i - frames = [] - frames.extend(buffer) # 把之前2个语音数据快加入 - if speech_end_i or RECORD_NUM > 300: - speech_start = False - audio_in = b"".join(frames) - websocket.speek.put(audio_in) - frames = [] # 清空所有的帧数据 - buffer = [] # 清空缓存中的帧数据(最多两个片段) - RECORD_NUM = 0 - if not websocket.send_msg.empty(): - await websocket.send(websocket.send_msg.get()) - websocket.send_msg.task_done() - - - except websockets.ConnectionClosed: - print("ConnectionClosed...", websocket_users) # 链接断开 - websocket_users.remove(websocket) - except websockets.InvalidState: - print("InvalidState...") # 无效状态 - except Exception as e: - print("Exception:", e) - - -def asr(websocket): # ASR推理 - global inference_pipeline_asr, inference_pipeline_punc - # global param_dict_punc - global websocket_users - while websocket in websocket_users: - if not websocket.speek.empty(): - audio_in = websocket.speek.get() - websocket.speek.task_done() - if len(audio_in) > 0: - rec_result = inference_pipeline_asr(audio_in=audio_in) - if inference_pipeline_punc is not None and 'text' in rec_result: - rec_result = inference_pipeline_punc(text_in=rec_result['text'], param_dict=websocket.param_dict_punc) - # print(rec_result) - if "text" in rec_result: - message = json.dumps({"mode": "offline", "text": rec_result["text"]}) - websocket.send_msg.put(message) # 存入发送队列 直接调用send发送不了 - - time.sleep(0.1) - -def vad(data, websocket): # VAD推理 - global inference_pipeline_vad - #print(type(data)) - # print(param_dict_vad) - segments_result = inference_pipeline_vad(audio_in=data, param_dict=websocket.param_dict_vad) - # print(segments_result) - # print(param_dict_vad) - speech_start = False - speech_end = False - - if len(segments_result) == 0 or len(segments_result["text"]) > 1: - return speech_start, speech_end - if segments_result["text"][0][0] != -1: - speech_start = True - if segments_result["text"][0][1] != -1: - speech_end = True - return speech_start, speech_end - - -start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None) -asyncio.get_event_loop().run_until_complete(start_server) -asyncio.get_event_loop().run_forever() \ No newline at end of file diff --git a/funasr/runtime/python/websocket/ASR_server_2pass.py b/funasr/runtime/python/websocket/ASR_server_2pass.py deleted file mode 100644 index 135a3cc34..000000000 --- a/funasr/runtime/python/websocket/ASR_server_2pass.py +++ /dev/null @@ -1,252 +0,0 @@ -import asyncio -import json -import websockets -import time -from queue import Queue -import threading -import argparse - -from modelscope.pipelines import pipeline -from modelscope.utils.constant import Tasks -from modelscope.utils.logger import get_logger -import logging -import tracemalloc -import numpy as np - -tracemalloc.start() - -logger = get_logger(log_level=logging.CRITICAL) -logger.setLevel(logging.CRITICAL) - - -websocket_users = set() #维护客户端列表 - -parser = argparse.ArgumentParser() -parser.add_argument("--host", - type=str, - default="0.0.0.0", - required=False, - help="host ip, localhost, 0.0.0.0") -parser.add_argument("--port", - type=int, - default=10095, - required=False, - help="grpc server port") -parser.add_argument("--asr_model", - type=str, - default="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", - help="model from modelscope") -parser.add_argument("--vad_model", - type=str, - default="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", - help="model from modelscope") - -parser.add_argument("--punc_model", - type=str, - default="damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727", - help="model from modelscope") -parser.add_argument("--ngpu", - type=int, - default=1, - help="0 for cpu, 1 for gpu") - -args = parser.parse_args() - -print("model loading") - -def load_bytes(input): - middle_data = np.frombuffer(input, dtype=np.int16) - middle_data = np.asarray(middle_data) - if middle_data.dtype.kind not in 'iu': - raise TypeError("'middle_data' must be an array of integers") - dtype = np.dtype('float32') - if dtype.kind != 'f': - raise TypeError("'dtype' must be a floating point type") - - i = np.iinfo(middle_data.dtype) - abs_max = 2 ** (i.bits - 1) - offset = i.min + abs_max - array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32) - return array - -# vad -inference_pipeline_vad = pipeline( - task=Tasks.voice_activity_detection, - model=args.vad_model, - model_revision=None, - output_dir=None, - batch_size=1, - mode='online', - ngpu=args.ngpu, -) -# param_dict_vad = {'in_cache': dict(), "is_final": False} - -# asr -param_dict_asr = {} -# param_dict["hotword"] = "小五 小五月" # 设置热词,用空格隔开 -inference_pipeline_asr = pipeline( - task=Tasks.auto_speech_recognition, - model=args.asr_model, - param_dict=param_dict_asr, - ngpu=args.ngpu, -) -if args.punc_model != "": - # param_dict_punc = {'cache': list()} - inference_pipeline_punc = pipeline( - task=Tasks.punctuation, - model=args.punc_model, - model_revision=None, - ngpu=args.ngpu, - ) -else: - inference_pipeline_punc = None - - -inference_pipeline_asr_online = pipeline( - task=Tasks.auto_speech_recognition, - model='damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online', - model_revision=None) - - -print("model loaded") - - - -async def ws_serve(websocket, path): - #speek = Queue() - frames = [] # 存储所有的帧数据 - frames_online = [] # 存储所有的帧数据 - buffer = [] # 存储缓存中的帧数据(最多两个片段) - RECORD_NUM = 0 - global websocket_users - speech_start, speech_end = False, False - # 调用asr函数 - websocket.param_dict_vad = {'in_cache': dict(), "is_final": False} - websocket.param_dict_punc = {'cache': list()} - websocket.speek = Queue() #websocket 添加进队列对象 让asr读取语音数据包 - websocket.send_msg = Queue() #websocket 添加个队列对象 让ws发送消息到客户端 - websocket_users.add(websocket) - ss = threading.Thread(target=asr, args=(websocket,)) - ss.start() - - websocket.param_dict_asr_online = {"cache": dict(), "is_final": False} - websocket.speek_online = Queue() # websocket 添加进队列对象 让asr读取语音数据包 - ss_online = threading.Thread(target=asr_online, args=(websocket,)) - ss_online.start() - - try: - async for message in websocket: - #voices.put(message) - #print("put") - #await websocket.send("123") - buffer.append(message) - if len(buffer) > 2: - buffer.pop(0) # 如果缓存超过两个片段,则删除最早的一个 - - if speech_start: - frames.append(message) - frames_online.append(message) - RECORD_NUM += 1 - if RECORD_NUM % 6 == 0: - audio_in = b"".join(frames_online) - websocket.speek_online.put(audio_in) - frames_online = [] - - speech_start_i, speech_end_i = vad(message, websocket) - #print(speech_start_i, speech_end_i) - if speech_start_i: - RECORD_NUM += 1 - speech_start = speech_start_i - frames = [] - frames.extend(buffer) # 把之前2个语音数据快加入 - frames_online = [] - frames_online.append(message) - # frames_online.extend(buffer) - # RECORD_NUM += 1 - websocket.param_dict_asr_online["is_final"] = False - if speech_end_i or RECORD_NUM > 300: - speech_start = False - audio_in = b"".join(frames) - websocket.speek.put(audio_in) - frames = [] # 清空所有的帧数据 - frames_online = [] - websocket.param_dict_asr_online["is_final"] = True - buffer = [] # 清空缓存中的帧数据(最多两个片段) - RECORD_NUM = 0 - if not websocket.send_msg.empty(): - await websocket.send(websocket.send_msg.get()) - websocket.send_msg.task_done() - - - except websockets.ConnectionClosed: - print("ConnectionClosed...", websocket_users) # 链接断开 - websocket_users.remove(websocket) - except websockets.InvalidState: - print("InvalidState...") # 无效状态 - except Exception as e: - print("Exception:", e) - - -def asr(websocket): # ASR推理 - global inference_pipeline_asr - # global param_dict_punc - global websocket_users - while websocket in websocket_users: - if not websocket.speek.empty(): - audio_in = websocket.speek.get() - websocket.speek.task_done() - if len(audio_in) > 0: - rec_result = inference_pipeline_asr(audio_in=audio_in) - if inference_pipeline_punc is not None and 'text' in rec_result: - rec_result = inference_pipeline_punc(text_in=rec_result['text'], param_dict=websocket.param_dict_punc) - # print(rec_result) - if "text" in rec_result: - message = json.dumps({"mode": "offline", "text": rec_result["text"]}) - websocket.send_msg.put(message) # 存入发送队列 直接调用send发送不了 - - time.sleep(0.1) - - -def asr_online(websocket): # ASR推理 - global inference_pipeline_asr_online - # global param_dict_punc - global websocket_users - while websocket in websocket_users: - if not websocket.speek_online.empty(): - audio_in = websocket.speek_online.get() - websocket.speek_online.task_done() - if len(audio_in) > 0: - # print(len(audio_in)) - audio_in = load_bytes(audio_in) - # print(audio_in.shape) - rec_result = inference_pipeline_asr_online(audio_in=audio_in, param_dict=websocket.param_dict_asr_online) - - # print(rec_result) - if "text" in rec_result: - message = json.dumps({"mode": "online", "text": rec_result["text"]}) - websocket.send_msg.put(message) # 存入发送队列 直接调用send发送不了 - - time.sleep(0.1) - -def vad(data, websocket): # VAD推理 - global inference_pipeline_vad, param_dict_vad - #print(type(data)) - # print(param_dict_vad) - segments_result = inference_pipeline_vad(audio_in=data, param_dict=websocket.param_dict_vad) - # print(segments_result) - # print(param_dict_vad) - speech_start = False - speech_end = False - - if len(segments_result) == 0 or len(segments_result["text"]) > 1: - return speech_start, speech_end - if segments_result["text"][0][0] != -1: - speech_start = True - if segments_result["text"][0][1] != -1: - speech_end = True - return speech_start, speech_end - - -start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None) -asyncio.get_event_loop().run_until_complete(start_server) -asyncio.get_event_loop().run_forever() \ No newline at end of file diff --git a/funasr/runtime/python/websocket/ASR_server_streaming.py b/funasr/runtime/python/websocket/ASR_server_streaming.py deleted file mode 100644 index b7c54f78c..000000000 --- a/funasr/runtime/python/websocket/ASR_server_streaming.py +++ /dev/null @@ -1,261 +0,0 @@ -import asyncio -import json -import websockets -import time -from queue import Queue -import threading -import argparse - -from modelscope.pipelines import pipeline -from modelscope.utils.constant import Tasks -from modelscope.utils.logger import get_logger -import logging -import tracemalloc -import numpy as np - -tracemalloc.start() - -logger = get_logger(log_level=logging.CRITICAL) -logger.setLevel(logging.CRITICAL) - - -websocket_users = set() #维护客户端列表 - -parser = argparse.ArgumentParser() -parser.add_argument("--host", - type=str, - default="0.0.0.0", - required=False, - help="host ip, localhost, 0.0.0.0") -parser.add_argument("--port", - type=int, - default=10095, - required=False, - help="grpc server port") -parser.add_argument("--asr_model", - type=str, - default="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", - help="model from modelscope") -parser.add_argument("--vad_model", - type=str, - default="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", - help="model from modelscope") - -parser.add_argument("--punc_model", - type=str, - default="damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727", - help="model from modelscope") -parser.add_argument("--ngpu", - type=int, - default=1, - help="0 for cpu, 1 for gpu") - -args = parser.parse_args() - -print("model loading") - -def load_bytes(input): - middle_data = np.frombuffer(input, dtype=np.int16) - middle_data = np.asarray(middle_data) - if middle_data.dtype.kind not in 'iu': - raise TypeError("'middle_data' must be an array of integers") - dtype = np.dtype('float32') - if dtype.kind != 'f': - raise TypeError("'dtype' must be a floating point type") - - i = np.iinfo(middle_data.dtype) - abs_max = 2 ** (i.bits - 1) - offset = i.min + abs_max - array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32) - return array - -# vad -inference_pipeline_vad = pipeline( - task=Tasks.voice_activity_detection, - model=args.vad_model, - model_revision=None, - output_dir=None, - batch_size=1, - mode='online', - ngpu=args.ngpu, -) -# param_dict_vad = {'in_cache': dict(), "is_final": False} - -# # asr -# param_dict_asr = {} -# # param_dict["hotword"] = "小五 小五月" # 设置热词,用空格隔开 -# inference_pipeline_asr = pipeline( -# task=Tasks.auto_speech_recognition, -# model=args.asr_model, -# param_dict=param_dict_asr, -# ngpu=args.ngpu, -# ) -# if args.punc_model != "": -# # param_dict_punc = {'cache': list()} -# inference_pipeline_punc = pipeline( -# task=Tasks.punctuation, -# model=args.punc_model, -# model_revision=None, -# ngpu=args.ngpu, -# ) -# else: -# inference_pipeline_punc = None - - -inference_pipeline_asr_online = pipeline( - task=Tasks.auto_speech_recognition, - model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online', - model_revision=None) - - -print("model loaded") - - - -async def ws_serve(websocket, path): - #speek = Queue() - frames = [] # 存储所有的帧数据 - frames_online = [] # 存储所有的帧数据 - buffer = [] # 存储缓存中的帧数据(最多两个片段) - RECORD_NUM = 0 - global websocket_users - speech_start, speech_end = False, False - # 调用asr函数 - websocket.param_dict_vad = {'in_cache': dict(), "is_final": False} - websocket.param_dict_punc = {'cache': list()} - websocket.speek = Queue() #websocket 添加进队列对象 让asr读取语音数据包 - websocket.send_msg = Queue() #websocket 添加个队列对象 让ws发送消息到客户端 - websocket_users.add(websocket) - # ss = threading.Thread(target=asr, args=(websocket,)) - # ss.start() - - websocket.param_dict_asr_online = {"cache": dict(), "is_final": False} - websocket.speek_online = Queue() # websocket 添加进队列对象 让asr读取语音数据包 - ss_online = threading.Thread(target=asr_online, args=(websocket,)) - ss_online.start() - - try: - async for data in websocket: - #voices.put(message) - #print("put") - #await websocket.send("123") - - data = json.loads(data) - # message = data["data"] - message = bytes(data['audio'], 'ISO-8859-1') - chunk = data["chunk"] - chunk_num = 600//chunk - is_speaking = data["is_speaking"] - websocket.param_dict_vad["is_final"] = not is_speaking - buffer.append(message) - if len(buffer) > 2: - buffer.pop(0) # 如果缓存超过两个片段,则删除最早的一个 - - if speech_start: - # frames.append(message) - frames_online.append(message) - # RECORD_NUM += 1 - if len(frames_online) % chunk_num == 0: - audio_in = b"".join(frames_online) - websocket.speek_online.put(audio_in) - frames_online = [] - - speech_start_i, speech_end_i = vad(message, websocket) - #print(speech_start_i, speech_end_i) - if speech_start_i: - # RECORD_NUM += 1 - speech_start = speech_start_i - # frames = [] - # frames.extend(buffer) # 把之前2个语音数据快加入 - frames_online = [] - # frames_online.append(message) - frames_online.extend(buffer) - # RECORD_NUM += 1 - websocket.param_dict_asr_online["is_final"] = False - if speech_end_i: - speech_start = False - # audio_in = b"".join(frames) - # websocket.speek.put(audio_in) - # frames = [] # 清空所有的帧数据 - frames_online = [] - websocket.param_dict_asr_online["is_final"] = True - # buffer = [] # 清空缓存中的帧数据(最多两个片段) - # RECORD_NUM = 0 - if not websocket.send_msg.empty(): - await websocket.send(websocket.send_msg.get()) - websocket.send_msg.task_done() - - - except websockets.ConnectionClosed: - print("ConnectionClosed...", websocket_users) # 链接断开 - websocket_users.remove(websocket) - except websockets.InvalidState: - print("InvalidState...") # 无效状态 - except Exception as e: - print("Exception:", e) - - -# def asr(websocket): # ASR推理 -# global inference_pipeline_asr -# # global param_dict_punc -# global websocket_users -# while websocket in websocket_users: -# if not websocket.speek.empty(): -# audio_in = websocket.speek.get() -# websocket.speek.task_done() -# if len(audio_in) > 0: -# rec_result = inference_pipeline_asr(audio_in=audio_in) -# if inference_pipeline_punc is not None and 'text' in rec_result: -# rec_result = inference_pipeline_punc(text_in=rec_result['text'], param_dict=websocket.param_dict_punc) -# # print(rec_result) -# if "text" in rec_result: -# message = json.dumps({"mode": "offline", "text": rec_result["text"]}) -# websocket.send_msg.put(message) # 存入发送队列 直接调用send发送不了 -# -# time.sleep(0.1) - - -def asr_online(websocket): # ASR推理 - global inference_pipeline_asr_online - # global param_dict_punc - global websocket_users - while websocket in websocket_users: - if not websocket.speek_online.empty(): - audio_in = websocket.speek_online.get() - websocket.speek_online.task_done() - if len(audio_in) > 0: - # print(len(audio_in)) - audio_in = load_bytes(audio_in) - # print(audio_in.shape) - rec_result = inference_pipeline_asr_online(audio_in=audio_in, param_dict=websocket.param_dict_asr_online) - - # print(rec_result) - if "text" in rec_result: - if rec_result["text"] != "sil" and rec_result["text"] != "waiting_for_more_voice": - message = json.dumps({"mode": "online", "text": rec_result["text"]}) - websocket.send_msg.put(message) # 存入发送队列 直接调用send发送不了 - - time.sleep(0.1) - -def vad(data, websocket): # VAD推理 - global inference_pipeline_vad, param_dict_vad - #print(type(data)) - # print(param_dict_vad) - segments_result = inference_pipeline_vad(audio_in=data, param_dict=websocket.param_dict_vad) - # print(segments_result) - # print(param_dict_vad) - speech_start = False - speech_end = False - - if len(segments_result) == 0 or len(segments_result["text"]) > 1: - return speech_start, speech_end - if segments_result["text"][0][0] != -1: - speech_start = True - if segments_result["text"][0][1] != -1: - speech_end = True - return speech_start, speech_end - - -start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None) -asyncio.get_event_loop().run_until_complete(start_server) -asyncio.get_event_loop().run_forever() \ No newline at end of file diff --git a/funasr/runtime/python/websocket/ASR_server_streaming_asr.py b/funasr/runtime/python/websocket/ASR_server_streaming_asr.py deleted file mode 100644 index b8b8b8d50..000000000 --- a/funasr/runtime/python/websocket/ASR_server_streaming_asr.py +++ /dev/null @@ -1,161 +0,0 @@ -import asyncio -import json -import websockets -import time -from queue import Queue -import threading -import argparse - -from modelscope.pipelines import pipeline -from modelscope.utils.constant import Tasks -from modelscope.utils.logger import get_logger -import logging -import tracemalloc -import numpy as np - -tracemalloc.start() - -logger = get_logger(log_level=logging.CRITICAL) -logger.setLevel(logging.CRITICAL) - - -websocket_users = set() #维护客户端列表 - -parser = argparse.ArgumentParser() -parser.add_argument("--host", - type=str, - default="0.0.0.0", - required=False, - help="host ip, localhost, 0.0.0.0") -parser.add_argument("--port", - type=int, - default=10095, - required=False, - help="grpc server port") -parser.add_argument("--asr_model", - type=str, - default="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", - help="model from modelscope") -parser.add_argument("--vad_model", - type=str, - default="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", - help="model from modelscope") - -parser.add_argument("--punc_model", - type=str, - default="damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727", - help="model from modelscope") -parser.add_argument("--ngpu", - type=int, - default=1, - help="0 for cpu, 1 for gpu") - -args = parser.parse_args() - -print("model loading") - -def load_bytes(input): - middle_data = np.frombuffer(input, dtype=np.int16) - middle_data = np.asarray(middle_data) - if middle_data.dtype.kind not in 'iu': - raise TypeError("'middle_data' must be an array of integers") - dtype = np.dtype('float32') - if dtype.kind != 'f': - raise TypeError("'dtype' must be a floating point type") - - i = np.iinfo(middle_data.dtype) - abs_max = 2 ** (i.bits - 1) - offset = i.min + abs_max - array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32) - return array - -inference_pipeline_asr_online = pipeline( - task=Tasks.auto_speech_recognition, - # model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online', - model='damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8404-online', - model_revision=None) - - -print("model loaded") - - - -async def ws_serve(websocket, path): - frames_online = [] - global websocket_users - websocket.send_msg = Queue() - websocket_users.add(websocket) - websocket.param_dict_asr_online = {"cache": dict()} - websocket.speek_online = Queue() - ss_online = threading.Thread(target=asr_online, args=(websocket,)) - ss_online.start() - ss_ws_send = threading.Thread(target=ws_send, args=(websocket,)) - ss_ws_send.start() - try: - async for message in websocket: - message = json.loads(message) - audio = bytes(message['audio'], 'ISO-8859-1') - chunk = message["chunk"] - chunk_num = 500//chunk - is_speaking = message["is_speaking"] - websocket.param_dict_asr_online["is_final"] = not is_speaking - frames_online.append(audio) - - if len(frames_online) % chunk_num == 0 or not is_speaking: - audio_in = b"".join(frames_online) - websocket.speek_online.put(audio_in) - frames_online = [] - - # if not websocket.send_msg.empty(): - # await websocket.send(websocket.send_msg.get()) - # websocket.send_msg.task_done() - - - except websockets.ConnectionClosed: - print("ConnectionClosed...", websocket_users) # 链接断开 - websocket_users.remove(websocket) - except websockets.InvalidState: - print("InvalidState...") # 无效状态 - except Exception as e: - print("Exception:", e) - - - -def ws_send(websocket): # ASR推理 - global inference_pipeline_asr_online - global websocket_users - while websocket in websocket_users: - if not websocket.speek_online.empty(): - await websocket.send(websocket.send_msg.get()) - websocket.send_msg.task_done() - time.sleep(0.005) - - -def asr_online(websocket): # ASR推理 - global websocket_users - while websocket in websocket_users: - if not websocket.send_msg.empty(): - audio_in = websocket.speek_online.get() - websocket.speek_online.task_done() - if len(audio_in) > 0: - # print(len(audio_in)) - audio_in = load_bytes(audio_in) - # print(audio_in.shape) - print(websocket.param_dict_asr_online["is_final"]) - rec_result = inference_pipeline_asr_online(audio_in=audio_in, - param_dict=websocket.param_dict_asr_online) - if websocket.param_dict_asr_online["is_final"]: - websocket.param_dict_asr_online["cache"] = dict() - - print(rec_result) - if "text" in rec_result: - if rec_result["text"] != "sil" and rec_result["text"] != "waiting_for_more_voice": - message = json.dumps({"mode": "online", "text": rec_result["text"]}) - websocket.send_msg.put(message) # 存入发送队列 直接调用send发送不了 - - time.sleep(0.005) - - -start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None) -asyncio.get_event_loop().run_until_complete(start_server) -asyncio.get_event_loop().run_forever() \ No newline at end of file diff --git a/funasr/runtime/python/websocket/parse_args.py b/funasr/runtime/python/websocket/parse_args.py new file mode 100644 index 000000000..2528a7624 --- /dev/null +++ b/funasr/runtime/python/websocket/parse_args.py @@ -0,0 +1,35 @@ +# -*- encoding: utf-8 -*- +import argparse +parser = argparse.ArgumentParser() +parser.add_argument("--host", + type=str, + default="0.0.0.0", + required=False, + help="host ip, localhost, 0.0.0.0") +parser.add_argument("--port", + type=int, + default=10095, + required=False, + help="grpc server port") +parser.add_argument("--asr_model", + type=str, + default="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", + help="model from modelscope") +parser.add_argument("--asr_model_online", + type=str, + default="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online", + help="model from modelscope") +parser.add_argument("--vad_model", + type=str, + default="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", + help="model from modelscope") +parser.add_argument("--punc_model", + type=str, + default="damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727", + help="model from modelscope") +parser.add_argument("--ngpu", + type=int, + default=1, + help="0 for cpu, 1 for gpu") + +args = parser.parse_args() \ No newline at end of file diff --git a/funasr/runtime/python/websocket/ASR_client.py b/funasr/runtime/python/websocket/ws_client.py similarity index 73% rename from funasr/runtime/python/websocket/ASR_client.py rename to funasr/runtime/python/websocket/ws_client.py index 9a4a14802..8bbf1032d 100644 --- a/funasr/runtime/python/websocket/ASR_client.py +++ b/funasr/runtime/python/websocket/ws_client.py @@ -1,4 +1,5 @@ # -*- encoding: utf-8 -*- +import os import time import websockets import asyncio @@ -18,29 +19,36 @@ parser.add_argument("--port", required=False, help="grpc server port") parser.add_argument("--chunk_size", + type=str, + default="5, 10, 5", + help="chunk") +parser.add_argument("--chunk_interval", type=int, - default=300, - help="ms") + default=10, + help="chunk") parser.add_argument("--audio_in", type=str, default=None, help="audio_in") args = parser.parse_args() +args.chunk_size = [int(x) for x in args.chunk_size.split(",")] # voices = asyncio.Queue() from queue import Queue voices = Queue() - + # 其他函数可以通过调用send(data)来发送数据,例如: async def record_microphone(): + is_finished = False import pyaudio #print("2") global voices FORMAT = pyaudio.paInt16 CHANNELS = 1 RATE = 16000 - CHUNK = int(RATE / 1000 * args.chunk_size) + chunk_size = 60*args.chunk_size[1]/args.chunk_interval + CHUNK = int(RATE / 1000 * chunk_size) p = pyaudio.PyAudio() @@ -54,7 +62,7 @@ async def record_microphone(): data = stream.read(CHUNK) data = data.decode('ISO-8859-1') - message = json.dumps({"chunk": args.chunk_size, "is_speaking": is_speaking, "audio": data}) + message = json.dumps({"chunk_size": args.chunk_size, "chunk_interval": args.chunk_interval, "audio": data, "is_speaking": is_speaking, "is_finished": is_finished}) voices.put(message) #print(voices.qsize()) @@ -65,6 +73,7 @@ async def record_microphone(): async def record_from_scp(): import wave global voices + is_finished = False if args.audio_in.endswith(".scp"): f_scp = open(args.audio_in) wavs = f_scp.readlines() @@ -86,9 +95,10 @@ async def record_from_scp(): # 将音频帧数据转换为字节类型的数据 audio_bytes = bytes(frames) - stride = int(args.chunk_size/1000*16000*2) + # stride = int(args.chunk_size/1000*16000*2) + stride = int(60*args.chunk_size[1]/args.chunk_interval/1000*16000*2) chunk_num = (len(audio_bytes)-1)//stride + 1 - print(stride) + # print(stride) is_speaking = True for i in range(chunk_num): if i == chunk_num-1: @@ -96,13 +106,16 @@ async def record_from_scp(): beg = i*stride data = audio_bytes[beg:beg+stride] data = data.decode('ISO-8859-1') - message = json.dumps({"chunk": args.chunk_size, "is_speaking": is_speaking, "audio": data}) + message = json.dumps({"chunk_size": args.chunk_size, "chunk_interval": args.chunk_interval, "is_speaking": is_speaking, "audio": data, "is_finished": is_finished}) voices.put(message) # print("data_chunk: ", len(data_chunk)) # print(voices.qsize()) - await asyncio.sleep(args.chunk_size/1000) - + await asyncio.sleep(60*args.chunk_size[1]/args.chunk_interval/1000) + + is_finished = True + message = json.dumps({"is_finished": is_finished}) + voices.put(message) async def ws_send(): global voices @@ -122,6 +135,24 @@ async def ws_send(): async def message(): + global websocket + text_print = "" + while True: + try: + meg = await websocket.recv() + meg = json.loads(meg) + # print(meg, end = '') + # print("\r") + text = meg["text"][0] + text_print += text + text_print = text_print[-55:] + os.system('clear') + print("\r"+text_print) + except Exception as e: + print("Exception:", e) + + +async def print_messge(): global websocket while True: try: @@ -129,8 +160,7 @@ async def message(): meg = json.loads(meg) print(meg) except Exception as e: - print("Exception:", e) - + print("Exception:", e) async def ws_client(): diff --git a/funasr/runtime/python/websocket/ws_server_online.py b/funasr/runtime/python/websocket/ws_server_online.py new file mode 100644 index 000000000..7ef0e2125 --- /dev/null +++ b/funasr/runtime/python/websocket/ws_server_online.py @@ -0,0 +1,108 @@ +import asyncio +import json +import websockets +import time +from queue import Queue +import threading +import logging +import tracemalloc +import numpy as np + +from parse_args import args +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger +from funasr_onnx.utils.frontend import load_bytes + +tracemalloc.start() + +logger = get_logger(log_level=logging.CRITICAL) +logger.setLevel(logging.CRITICAL) + + +websocket_users = set() + + +print("model loading") + +inference_pipeline_asr_online = pipeline( + task=Tasks.auto_speech_recognition, + model=args.asr_model_online, + model_revision='v1.0.4') + +print("model loaded") + + + +async def ws_serve(websocket, path): + frames_online = [] + global websocket_users + websocket.send_msg = Queue() + websocket_users.add(websocket) + websocket.param_dict_asr_online = {"cache": dict()} + websocket.speek_online = Queue() + ss_online = threading.Thread(target=asr_online, args=(websocket,)) + ss_online.start() + + try: + async for message in websocket: + message = json.loads(message) + is_finished = message["is_finished"] + if not is_finished: + audio = bytes(message['audio'], 'ISO-8859-1') + + is_speaking = message["is_speaking"] + websocket.param_dict_asr_online["is_final"] = not is_speaking + + websocket.param_dict_asr_online["chunk_size"] = message["chunk_size"] + + + frames_online.append(audio) + + if len(frames_online) % message["chunk_interval"] == 0 or not is_speaking: + + audio_in = b"".join(frames_online) + websocket.speek_online.put(audio_in) + frames_online = [] + + if not websocket.send_msg.empty(): + await websocket.send(websocket.send_msg.get()) + websocket.send_msg.task_done() + + + except websockets.ConnectionClosed: + print("ConnectionClosed...", websocket_users) # 链接断开 + websocket_users.remove(websocket) + except websockets.InvalidState: + print("InvalidState...") # 无效状态 + except Exception as e: + print("Exception:", e) + + + +def asr_online(websocket): # ASR推理 + global websocket_users + while websocket in websocket_users: + if not websocket.speek_online.empty(): + audio_in = websocket.speek_online.get() + websocket.speek_online.task_done() + if len(audio_in) > 0: + # print(len(audio_in)) + audio_in = load_bytes(audio_in) + rec_result = inference_pipeline_asr_online(audio_in=audio_in, + param_dict=websocket.param_dict_asr_online) + if websocket.param_dict_asr_online["is_final"]: + websocket.param_dict_asr_online["cache"] = dict() + + if "text" in rec_result: + if rec_result["text"] != "sil" and rec_result["text"] != "waiting_for_more_voice": + print(rec_result["text"]) + message = json.dumps({"mode": "online", "text": rec_result["text"]}) + websocket.send_msg.put(message) + + time.sleep(0.005) + + +start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None) +asyncio.get_event_loop().run_until_complete(start_server) +asyncio.get_event_loop().run_forever() \ No newline at end of file From 074bd57273f9cbe37daea7a8a744e760fcaee19f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Thu, 27 Apr 2023 17:12:16 +0800 Subject: [PATCH 14/15] websocket --- funasr/runtime/python/websocket/README.md | 33 ++++++++++++++++------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/funasr/runtime/python/websocket/README.md b/funasr/runtime/python/websocket/README.md index ba7230ae7..723782f74 100644 --- a/funasr/runtime/python/websocket/README.md +++ b/funasr/runtime/python/websocket/README.md @@ -5,7 +5,7 @@ The audio data is in streaming, the asr inference process is in offline. ## For the Server -Install the modelscope and funasr +### Install the modelscope and funasr ```shell pip install -U modelscope funasr @@ -14,23 +14,34 @@ pip install -U modelscope funasr git clone https://github.com/alibaba/FunASR.git && cd FunASR ``` -Install the requirements for server +### Install the requirements for server ```shell cd funasr/runtime/python/websocket pip install -r requirements_server.txt ``` -Start server +### Start server +#### ASR offline server -```shell -python ASR_server.py --host "0.0.0.0" --port 10095 --asr_model "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" -``` -For the paraformer 2pass model +[//]: # (```shell) +[//]: # (python ws_server_online.py --host "0.0.0.0" --port 10095 --asr_model "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch") + +[//]: # (```) +#### ASR streaming server ```shell -python ASR_server_2pass.py --host "0.0.0.0" --port 10095 --asr_model "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" +python ws_server_online.py --host "0.0.0.0" --port 10095 ``` +#### + +#### ASR offline/online 2pass server + +[//]: # (```shell) + +[//]: # (python ws_server_online.py --host "0.0.0.0" --port 10095 --asr_model "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch") + +[//]: # (```) ## For the client @@ -44,8 +55,10 @@ pip install -r requirements_client.txt Start client ```shell -python ASR_client.py --host "127.0.0.1" --port 10095 --chunk_size 50 +# --chunk_size, "5,10,5"=600ms, "8,8,4"=480ms +python ws_client.py --host "127.0.0.1" --port 10096 --chunk_size "5,10,5" ``` ## Acknowledge -1. We acknowledge [cgisky1980](https://github.com/cgisky1980/FunASR) for contributing the websocket service. +1. This project is maintained by [FunASR community](https://github.com/alibaba-damo-academy/FunASR). +2. We acknowledge [cgisky1980](https://github.com/cgisky1980/FunASR) for contributing the websocket service. From d9db469ed4cddff2f24d6df5f22b1175d43061cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Thu, 27 Apr 2023 17:22:27 +0800 Subject: [PATCH 15/15] websocket --- .gitignore | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index b0fa5430b..33b8c3979 100644 --- a/.gitignore +++ b/.gitignore @@ -16,5 +16,4 @@ MaaS-lib .egg* dist build -funasr.egg-info -sherpa \ No newline at end of file +funasr.egg-info \ No newline at end of file