Merge branch 'dev_gzf_deepspeed' of http://gitlab.alibaba-inc.com/zhifu.gzf/FunASR into dev_gzf_deepspeed

This commit is contained in:
木守 2024-09-03 19:45:35 +08:00
commit 245ba8fd45

View File

@ -27,9 +27,9 @@ class NlsTtsSynthesizer:
self.token = token
self.appkey = appkey
self.sdk = None
self.init_sdk()
self.started = False
self.count = 0
self.init_sdk()
def init_sdk(self):
# 配置回调函数
@ -49,7 +49,6 @@ class NlsTtsSynthesizer:
# Data callback: adds len(data) to the running total in self.count, prints
# that total, and appends the chunk to self.tts_fifo.
# `data` is presumably a chunk of synthesized PCM audio bytes (the disabled
# code below would append it to 'tts_server.pcm') -- TODO confirm.
# Extra positional arguments are accepted and ignored.
def on_data(self, data, *args):
self.count += len(data)
print(f"cout: {self.count}")
# presumably the same queue that send_to_client() drains with popleft() -- verify against the caller
self.tts_fifo.append(data)
# Disabled debug aid: dump every received chunk to a local file.
# with open('tts_server.pcm', 'ab') as file:
# file.write(data)
@ -64,25 +63,24 @@ class NlsTtsSynthesizer:
print("on sentence end =>{}".format(message))
# Completion callback: prints the total accumulated in self.count (on_data adds
# len(data) to it for every chunk), prints the completion message, and clears
# the `started` flag. Extra positional arguments are accepted and ignored.
def on_completed(self, message, *args):
print("on message data cout: =>{}".format(self.count))
print("on completed =>{}".format(message))
# send_to_client() reads `started` to decide whether it should keep polling for audio.
self.started = False
def on_error(self, message, *args):
    """Error callback: report the error payload and any extra arguments.

    Args:
        message: Error payload handed to the callback.
        *args: Additional positional arguments supplied by the caller.

    The handler only logs; it does not change ``self.started``.
    """
    # The payload itself was previously dropped, leaving only `args` in the
    # output; print it as well so the actual error is visible.
    print("on_error message=>{}".format(message))
    print("on_error args=>{}".format(args))
# Close callback: prints the arguments it was given and the total accumulated
# in self.count, then clears the `started` flag.
def on_close(self, *args):
print("on_close: args=>{}".format(args))
print("on message data cout: =>{}".format(self.count))
# send_to_client() reads `started` to decide whether it should keep polling for audio.
self.started = False
# Begin a streaming TTS session by calling the SDK's startStreamInputTts(),
# then mark this synthesizer as started.
def start(self):
self.sdk.startStreamInputTts()
# Set after the SDK call; on_completed() and on_close() reset it to False.
self.started = True
# Print the text for tracing and pass it to the SDK's sendStreamInputTts().
# `text` is forwarded unchanged.
def send_text(self, text):
print(f"text: {text}")
self.sdk.sendStreamInputTts(text)
def stop(self):
    """End the streaming TTS session by calling the SDK's stopStreamInputTts().

    This is deliberately a plain (non-async) method: model_inference() calls
    ``synthesizer.stop()`` without ``await``, so a coroutine function here
    would only create an un-awaited coroutine and the SDK call would never
    run.

    Returns:
        None.
    """
    self.sdk.stopStreamInputTts()
@ -137,7 +135,6 @@ websocket_users = set()
print("model loading")
from funasr import AutoModel
# vad
model_vad = AutoModel(
model=args.vad_model,
@ -150,7 +147,6 @@ model_vad = AutoModel(
# chunk_size=60,
)
import os
# from install_model_requirements import install_requirements
@ -239,7 +235,7 @@ model_dict = {"model": model, "frontend": frontend, "tokenizer": tokenizer}
async def send_to_client(websocket, syntheszier, tts_fifo):
# Sending tts data to the client
while True:
if websocket.open and syntheszier.started:
if websocket.open and (syntheszier.started or len(tts_fifo) > 0):
try:
if len(tts_fifo) > 0:
await websocket.send(tts_fifo.popleft())
@ -252,6 +248,10 @@ async def send_to_client(websocket, syntheszier, tts_fifo):
break
# Thread entry point: runs the coroutine object `coro` to completion with
# asyncio.run(), which creates a new event loop for the calling thread and
# closes it when the coroutine finishes. model_inference() uses this as the
# target of a Thread so that send_to_client() runs alongside text generation.
# NOTE(review): the coroutine passed in uses a websocket that was accepted on
# the server's own event loop -- confirm that awaiting websocket.send() from a
# second loop/thread is supported by the websocket library in use.
def tts_sync_thread(coro):
asyncio.run(coro)
async def model_inference(
websocket,
audio_in,
@ -317,6 +317,10 @@ async def model_inference(
thread.start()
res = ""
beg_llm = time.time()
tts_thread = Thread(
target=tts_sync_thread, args=(send_to_client(websocket, synthesizer, fifo_queue),)
)
tts_thread.start()
for new_text in streamer:
end_llm = time.time()
print(
@ -342,14 +346,10 @@ async def model_inference(
)
# print(f"online: {message}")
await websocket.send(message)
if len(fifo_queue) > 0:
while len(fifo_queue) > 0:
await websocket.send(fifo_queue.popleft())
# synthesizer.send_text(res)
tts_to_client_task = asyncio.create_task(send_to_client(websocket, synthesizer, fifo_queue))
synthesizer.stop()
await tts_to_client_task
# await tts_to_client_task
tts_thread.join()
mode = "2pass-offline"
message = json.dumps(
{