Merge branch dev_lr_deepspeed into dev_gzf_deepspeed

Title: Add llm tts to client process. 

本次代码评审主要增强了 WebSocket 服务端和客户端的语音合成(TTS)功能,添加了语音数据发送、接收处理和计数逻辑,优化了模块结构,引入了`NlsTtsSynthesizer`类来管理语音合成流程,并调整了错误处理和连接管理,使得语音传输更稳定且可追踪。
Link: https://code.alibaba-inc.com/zhifu.gzf/FunASR/codereview/18178626
This commit is contained in:
zhifu.gzf 2024-09-02 10:22:55 +08:00
commit 623fd16f34
2 changed files with 103 additions and 6 deletions

View File

@ -238,7 +238,7 @@ async def record_from_scp(chunk_begin, chunk_size):
while not offline_msg_done:
await asyncio.sleep(1)
await websocket.close()
#await websocket.close()
async def message(id):
@ -246,6 +246,7 @@ async def message(id):
text_print = ""
text_print_2pass_online = ""
text_print_2pass_offline = ""
tts_count = 0
if args.output_dir is not None:
ibest_writer = open(
os.path.join(args.output_dir, "text.{}".format(id)), "a", encoding="utf-8"
@ -253,9 +254,18 @@ async def message(id):
else:
ibest_writer = None
try:
timestamp = int(time.time())
file_name = f'tts_client_{timestamp}.pcm'
file = open(file_name, 'wb')
while True:
meg = await websocket.recv()
if isinstance(meg, bytes):
try:
tts_count += len(meg)
file.write(meg)
except Exception as e:
print(e)
continue
meg = json.loads(meg)
wav_name = meg.get("wav_name", "demo")
text = meg["text"]
@ -296,8 +306,9 @@ async def message(id):
text_print_2pass_online = ""
text_print = text_print_2pass_offline + "{}".format(text)
text_print_2pass_offline += "{}".format(text)
text_print = text_print[-args.words_max_print :]
#text_print = text_print[-args.words_max_print :]
os.system("clear")
print("tts_count len: =>{}".format(tts_count))
print("\rpid" + str(id) + ": " + text_print)
# offline_msg_done=True
@ -305,8 +316,6 @@ async def message(id):
print("Exception:", e)
# traceback.print_exc()
# await websocket.close()
async def ws_client(id, chunk_begin, chunk_size):
if args.audio_in is None:
chunk_begin = 0

View File

@ -7,6 +7,71 @@ import tracemalloc
import numpy as np
import argparse
import ssl
import nls
from collections import deque
import threading
class NlsTtsSynthesizer:
def __init__(self, websocket, tts_fifo, token, appkey, url="wss://nls-gateway-cn-beijing.aliyuncs.com/ws/v1"):
self.websocket = websocket
self.tts_fifo = tts_fifo
self.url = url
self.token = token
self.appkey = appkey
self.sdk = None
self.init_sdk()
self.started = False
self.count = 0
def init_sdk(self):
# 配置回调函数
self.sdk = nls.NlsStreamInputTtsSynthesizer(
url=self.url,
token=self.token,
appkey=self.appkey,
on_data=self.on_data,
on_sentence_begin=self.on_sentence_begin,
on_sentence_synthesis=self.on_sentence_synthesis,
on_sentence_end=self.on_sentence_end,
on_completed=self.on_completed,
on_error=self.on_error,
on_close=self.on_close,
callback_args=[]
)
def on_data(self, data, *args):
self.count += len(data)
self.tts_fifo.append(data)
#with open('tts_server.pcm', 'ab') as file:
# file.write(data)
def on_sentence_begin(self, message, *args):
print('on sentence begin =>{}'.format(message))
def on_sentence_synthesis(self, message, *args):
print('on sentence synthesis =>{}'.format(message))
def on_sentence_end(self, message, *args):
print('on sentence end =>{}'.format(message))
def on_completed(self, message, *args):
print('on completed =>{}'.format(message))
def on_error(self, message, *args):
print('on_error args=>{}'.format(args))
def on_close(self, *args):
print('on_close: args=>{}'.format(args))
print("on message data cout: =>{}".format(self.count))
self.started = False
def start(self):
self.sdk.startStreamInputTts()
self.started = True
def send_text(self, text):
self.sdk.sendStreamInputTts(text)
async def stop(self):
self.sdk.stopStreamInputTts()
parser = argparse.ArgumentParser()
parser.add_argument(
@ -181,7 +246,20 @@ tokenizer = model_llm.kwargs["tokenizer"]
model_dict = {"model": model, "frontend": frontend, "tokenizer": tokenizer}
async def send_to_client(websocket, syntheszier, tts_fifo):
# Sending tts data to the client
while True:
if websocket.open and syntheszier.started:
try:
if len(tts_fifo) > 0:
await websocket.send(tts_fifo.popleft())
else:
await asyncio.sleep(0.01)
except Exception as e:
print(f"Error sending data to client: {e}")
else:
print("WebSocket connection is not open or syntheszier is not started.")
break
async def model_inference(
websocket,
audio_in,
@ -192,6 +270,9 @@ async def model_inference(
history=None,
text_usr="",
):
fifo_queue = deque()
synthesizer = NlsTtsSynthesizer(websocket=websocket, tts_fifo=fifo_queue, token="xxx", appkey="xxx")
synthesizer.start()
beg0 = time.time()
if his_state is None:
his_state = model_dict
@ -248,6 +329,7 @@ async def model_inference(
f"generated new text {new_text}, time_fr_receive: {end_llm - beg0:.2f}, time_llm_decode: {end_llm - beg_llm:.2f}"
)
if len(new_text) > 0:
synthesizer.send_text(new_text)
res += new_text.replace("<|im_end|>", "")
contents_i[-1]["content"] = res
websocket.llm_state["contents_i"] = contents_i
@ -264,7 +346,13 @@ async def model_inference(
)
# print(f"online: {message}")
await websocket.send(message)
if len(fifo_queue) > 0:
while len(fifo_queue) > 0:
await websocket.send(fifo_queue.popleft())
tts_to_client_task = asyncio.create_task(send_to_client(websocket, synthesizer, fifo_queue))
synthesizer.stop()
await tts_to_client_task
mode = "2pass-offline"
message = json.dumps(
{