mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
Merge branch dev_lr_deepspeed into dev_gzf_deepspeed
Title: Add llm tts to client process. 本次代码评审主要增强了 WebSocket 服务端和客户端的语音合成(TTS)功能,添加了语音数据发送、接收处理和计数逻辑,优化了模块结构,引入了`NlsTtsSynthesizer`类来管理语音合成流程,并调整了错误处理和连接管理,使得语音传输更稳定且可追踪。 Link: https://code.alibaba-inc.com/zhifu.gzf/FunASR/codereview/18178626
This commit is contained in:
commit
623fd16f34
@ -238,7 +238,7 @@ async def record_from_scp(chunk_begin, chunk_size):
|
||||
while not offline_msg_done:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
await websocket.close()
|
||||
#await websocket.close()
|
||||
|
||||
|
||||
async def message(id):
|
||||
@ -246,6 +246,7 @@ async def message(id):
|
||||
text_print = ""
|
||||
text_print_2pass_online = ""
|
||||
text_print_2pass_offline = ""
|
||||
tts_count = 0
|
||||
if args.output_dir is not None:
|
||||
ibest_writer = open(
|
||||
os.path.join(args.output_dir, "text.{}".format(id)), "a", encoding="utf-8"
|
||||
@ -253,9 +254,18 @@ async def message(id):
|
||||
else:
|
||||
ibest_writer = None
|
||||
try:
|
||||
timestamp = int(time.time())
|
||||
file_name = f'tts_client_{timestamp}.pcm'
|
||||
file = open(file_name, 'wb')
|
||||
while True:
|
||||
|
||||
meg = await websocket.recv()
|
||||
if isinstance(meg, bytes):
|
||||
try:
|
||||
tts_count += len(meg)
|
||||
file.write(meg)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
continue
|
||||
meg = json.loads(meg)
|
||||
wav_name = meg.get("wav_name", "demo")
|
||||
text = meg["text"]
|
||||
@ -296,8 +306,9 @@ async def message(id):
|
||||
text_print_2pass_online = ""
|
||||
text_print = text_print_2pass_offline + "{}".format(text)
|
||||
text_print_2pass_offline += "{}".format(text)
|
||||
text_print = text_print[-args.words_max_print :]
|
||||
#text_print = text_print[-args.words_max_print :]
|
||||
os.system("clear")
|
||||
print("tts_count len: =>{}".format(tts_count))
|
||||
print("\rpid" + str(id) + ": " + text_print)
|
||||
# offline_msg_done=True
|
||||
|
||||
@ -305,8 +316,6 @@ async def message(id):
|
||||
print("Exception:", e)
|
||||
# traceback.print_exc()
|
||||
# await websocket.close()
|
||||
|
||||
|
||||
async def ws_client(id, chunk_begin, chunk_size):
|
||||
if args.audio_in is None:
|
||||
chunk_begin = 0
|
||||
|
||||
@ -7,6 +7,71 @@ import tracemalloc
|
||||
import numpy as np
|
||||
import argparse
|
||||
import ssl
|
||||
import nls
|
||||
from collections import deque
|
||||
import threading
|
||||
|
||||
class NlsTtsSynthesizer:
|
||||
def __init__(self, websocket, tts_fifo, token, appkey, url="wss://nls-gateway-cn-beijing.aliyuncs.com/ws/v1"):
|
||||
self.websocket = websocket
|
||||
self.tts_fifo = tts_fifo
|
||||
self.url = url
|
||||
self.token = token
|
||||
self.appkey = appkey
|
||||
self.sdk = None
|
||||
self.init_sdk()
|
||||
self.started = False
|
||||
self.count = 0
|
||||
|
||||
def init_sdk(self):
|
||||
# 配置回调函数
|
||||
self.sdk = nls.NlsStreamInputTtsSynthesizer(
|
||||
url=self.url,
|
||||
token=self.token,
|
||||
appkey=self.appkey,
|
||||
on_data=self.on_data,
|
||||
on_sentence_begin=self.on_sentence_begin,
|
||||
on_sentence_synthesis=self.on_sentence_synthesis,
|
||||
on_sentence_end=self.on_sentence_end,
|
||||
on_completed=self.on_completed,
|
||||
on_error=self.on_error,
|
||||
on_close=self.on_close,
|
||||
callback_args=[]
|
||||
)
|
||||
def on_data(self, data, *args):
|
||||
self.count += len(data)
|
||||
self.tts_fifo.append(data)
|
||||
#with open('tts_server.pcm', 'ab') as file:
|
||||
# file.write(data)
|
||||
def on_sentence_begin(self, message, *args):
|
||||
print('on sentence begin =>{}'.format(message))
|
||||
|
||||
def on_sentence_synthesis(self, message, *args):
|
||||
print('on sentence synthesis =>{}'.format(message))
|
||||
|
||||
def on_sentence_end(self, message, *args):
|
||||
print('on sentence end =>{}'.format(message))
|
||||
|
||||
def on_completed(self, message, *args):
|
||||
print('on completed =>{}'.format(message))
|
||||
|
||||
def on_error(self, message, *args):
|
||||
print('on_error args=>{}'.format(args))
|
||||
|
||||
def on_close(self, *args):
|
||||
print('on_close: args=>{}'.format(args))
|
||||
print("on message data cout: =>{}".format(self.count))
|
||||
self.started = False
|
||||
|
||||
def start(self):
|
||||
self.sdk.startStreamInputTts()
|
||||
self.started = True
|
||||
|
||||
def send_text(self, text):
|
||||
self.sdk.sendStreamInputTts(text)
|
||||
|
||||
async def stop(self):
|
||||
self.sdk.stopStreamInputTts()
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
@ -181,7 +246,20 @@ tokenizer = model_llm.kwargs["tokenizer"]
|
||||
|
||||
model_dict = {"model": model, "frontend": frontend, "tokenizer": tokenizer}
|
||||
|
||||
|
||||
async def send_to_client(websocket, syntheszier, tts_fifo):
|
||||
# Sending tts data to the client
|
||||
while True:
|
||||
if websocket.open and syntheszier.started:
|
||||
try:
|
||||
if len(tts_fifo) > 0:
|
||||
await websocket.send(tts_fifo.popleft())
|
||||
else:
|
||||
await asyncio.sleep(0.01)
|
||||
except Exception as e:
|
||||
print(f"Error sending data to client: {e}")
|
||||
else:
|
||||
print("WebSocket connection is not open or syntheszier is not started.")
|
||||
break
|
||||
async def model_inference(
|
||||
websocket,
|
||||
audio_in,
|
||||
@ -192,6 +270,9 @@ async def model_inference(
|
||||
history=None,
|
||||
text_usr="",
|
||||
):
|
||||
fifo_queue = deque()
|
||||
synthesizer = NlsTtsSynthesizer(websocket=websocket, tts_fifo=fifo_queue, token="xxx", appkey="xxx")
|
||||
synthesizer.start()
|
||||
beg0 = time.time()
|
||||
if his_state is None:
|
||||
his_state = model_dict
|
||||
@ -248,6 +329,7 @@ async def model_inference(
|
||||
f"generated new text: {new_text}, time_fr_receive: {end_llm - beg0:.2f}, time_llm_decode: {end_llm - beg_llm:.2f}"
|
||||
)
|
||||
if len(new_text) > 0:
|
||||
synthesizer.send_text(new_text)
|
||||
res += new_text.replace("<|im_end|>", "")
|
||||
contents_i[-1]["content"] = res
|
||||
websocket.llm_state["contents_i"] = contents_i
|
||||
@ -264,7 +346,13 @@ async def model_inference(
|
||||
)
|
||||
# print(f"online: {message}")
|
||||
await websocket.send(message)
|
||||
if len(fifo_queue) > 0:
|
||||
while len(fifo_queue) > 0:
|
||||
await websocket.send(fifo_queue.popleft())
|
||||
|
||||
tts_to_client_task = asyncio.create_task(send_to_client(websocket, synthesizer, fifo_queue))
|
||||
synthesizer.stop()
|
||||
await tts_to_client_task
|
||||
mode = "2pass-offline"
|
||||
message = json.dumps(
|
||||
{
|
||||
|
||||
Loading…
Reference in New Issue
Block a user