diff --git a/runtime/python/websocket/funasr_wss_server_llm.py b/runtime/python/websocket/funasr_wss_server_llm.py
index 2c7fe6a57..bd3ffde40 100644
--- a/runtime/python/websocket/funasr_wss_server_llm.py
+++ b/runtime/python/websocket/funasr_wss_server_llm.py
@@ -27,9 +27,9 @@ class NlsTtsSynthesizer:
         self.token = token
         self.appkey = appkey
         self.sdk = None
-        self.init_sdk()
         self.started = False
         self.count = 0
+        self.init_sdk()
 
     def init_sdk(self):
         # 配置回调函数
@@ -49,7 +49,6 @@ class NlsTtsSynthesizer:
 
     def on_data(self, data, *args):
         self.count += len(data)
-        print(f"cout: {self.count}")
         self.tts_fifo.append(data)
         # with open('tts_server.pcm', 'ab') as file:
         #     file.write(data)
@@ -64,25 +63,24 @@ class NlsTtsSynthesizer:
         print("on sentence end =>{}".format(message))
 
     def on_completed(self, message, *args):
+        print("on message data cout: =>{}".format(self.count))
         print("on completed =>{}".format(message))
+        self.started = False
 
     def on_error(self, message, *args):
         print("on_error args=>{}".format(args))
 
     def on_close(self, *args):
         print("on_close: args=>{}".format(args))
-        print("on message data cout: =>{}".format(self.count))
-        self.started = False
 
     def start(self):
         self.sdk.startStreamInputTts()
         self.started = True
 
     def send_text(self, text):
-        print(f"text: {text}")
         self.sdk.sendStreamInputTts(text)
 
-    async def stop(self):
+    def stop(self):
         self.sdk.stopStreamInputTts()
 
 
@@ -137,7 +135,6 @@ websocket_users = set()
 
 print("model loading")
 from funasr import AutoModel
-
 # vad
 model_vad = AutoModel(
     model=args.vad_model,
@@ -150,7 +147,6 @@ model_vad = AutoModel(
     # chunk_size=60,
 )
 
-
 import os
 
 # from install_model_requirements import install_requirements
@@ -239,7 +235,7 @@ model_dict = {"model": model, "frontend": frontend, "tokenizer": tokenizer}
 async def send_to_client(websocket, syntheszier, tts_fifo):
     # Sending tts data to the client
     while True:
-        if websocket.open and syntheszier.started:
+        if websocket.open and (syntheszier.started or len(tts_fifo) > 0):
             try:
                 if len(tts_fifo) > 0:
                     await websocket.send(tts_fifo.popleft())
@@ -252,6 +248,10 @@ async def send_to_client(websocket, syntheszier, tts_fifo):
             break
 
 
+def tts_sync_thread(coro):
+    asyncio.run(coro)
+
+
 async def model_inference(
     websocket,
     audio_in,
@@ -317,6 +317,10 @@ async def model_inference(
         thread.start()
         res = ""
         beg_llm = time.time()
+        tts_thread = Thread(
+            target=tts_sync_thread, args=(send_to_client(websocket, synthesizer, fifo_queue),)
+        )
+        tts_thread.start()
         for new_text in streamer:
             end_llm = time.time()
             print(
@@ -342,14 +346,10 @@ async def model_inference(
             )
             # print(f"online: {message}")
             await websocket.send(message)
-            if len(fifo_queue) > 0:
-                while len(fifo_queue) > 0:
-                    await websocket.send(fifo_queue.popleft())
-        # synthesizer.send_text(res)
-        tts_to_client_task = asyncio.create_task(send_to_client(websocket, synthesizer, fifo_queue))
         synthesizer.stop()
-        await tts_to_client_task
+        # await tts_to_client_task
+        tts_thread.join()
         mode = "2pass-offline"
         message = json.dumps(
             {