mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
Merge branch llm_dev_gzf into dev_gzf_deepspeed
Title: lora 本次代码评审主要涉及对一个基于WebSocket的语音识别与处理系统的更新,包括添加日志打印、修改默认参数、增加对LoRA模型的支持、调整错误处理逻辑、优化音频处理和文本显示逻辑,以及代码结构和注释的若干改进,旨在提升系统稳定性和用户体验。 Link: https://code.alibaba-inc.com/zhifu.gzf/FunASR/codereview/18202582
This commit is contained in:
commit
6aa4f30555
@ -972,11 +972,17 @@ class LLMASR4(nn.Module):
|
||||
|
||||
lora_init_param_path = lora_conf.get("init_param_path", None)
|
||||
if lora_init_param_path is not None:
|
||||
logging.info(f"lora_init_param_path: {lora_init_param_path}")
|
||||
model = PeftModel.from_pretrained(model, lora_init_param_path)
|
||||
for name, param in model.named_parameters():
|
||||
if not lora_conf.get("freeze_lora", False):
|
||||
if "lora_" in name:
|
||||
param.requires_grad = True
|
||||
else:
|
||||
peft_config = LoraConfig(**lora_conf)
|
||||
model = get_peft_model(model, peft_config)
|
||||
model.print_trainable_parameters()
|
||||
|
||||
model.print_trainable_parameters()
|
||||
|
||||
if llm_conf.get("activation_checkpoint", False):
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
@ -225,6 +225,14 @@ class Trainer:
|
||||
ckpt_name = f"ds-model.pt.ep{epoch}.{step}"
|
||||
filename = os.path.join(self.output_dir, ckpt_name)
|
||||
|
||||
if self.use_lora:
|
||||
lora_outdir = f"{self.output_dir}/lora-{ckpt_name}"
|
||||
os.makedirs(lora_outdir, exist_ok=True)
|
||||
if hasattr(model, "module"):
|
||||
model.module.llm.save_pretrained(lora_outdir)
|
||||
else:
|
||||
model.llm.save_pretrained(lora_outdir)
|
||||
|
||||
# torch.save(state, filename)
|
||||
with torch.no_grad():
|
||||
model.save_checkpoint(save_dir=self.output_dir, tag=ckpt_name, client_state=state)
|
||||
|
||||
@ -20,7 +20,7 @@ parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--host", type=str, default="localhost", required=False, help="host ip, localhost, 0.0.0.0"
|
||||
)
|
||||
parser.add_argument("--port", type=int, default=10095, required=False, help="grpc server port")
|
||||
parser.add_argument("--port", type=int, default=10096, required=False, help="grpc server port")
|
||||
parser.add_argument("--chunk_size", type=str, default="5, 10, 5", help="chunk")
|
||||
parser.add_argument("--encoder_chunk_look_back", type=int, default=4, help="chunk")
|
||||
parser.add_argument("--decoder_chunk_look_back", type=int, default=0, help="chunk")
|
||||
@ -42,7 +42,7 @@ parser.add_argument(
|
||||
parser.add_argument("--thread_num", type=int, default=1, help="thread_num")
|
||||
parser.add_argument("--words_max_print", type=int, default=10000, help="chunk")
|
||||
parser.add_argument("--output_dir", type=str, default=None, help="output_dir")
|
||||
parser.add_argument("--ssl", type=int, default=1, help="1 for ssl connect, 0 for no ssl")
|
||||
parser.add_argument("--ssl", type=int, default=0, help="1 for ssl connect, 0 for no ssl")
|
||||
parser.add_argument("--use_itn", type=int, default=1, help="1 for using itn, 0 for not itn")
|
||||
parser.add_argument("--mode", type=str, default="2pass", help="offline, online, 2pass")
|
||||
|
||||
@ -238,7 +238,7 @@ async def record_from_scp(chunk_begin, chunk_size):
|
||||
while not offline_msg_done:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
#await websocket.close()
|
||||
# await websocket.close()
|
||||
|
||||
|
||||
async def message(id):
|
||||
@ -255,8 +255,8 @@ async def message(id):
|
||||
ibest_writer = None
|
||||
try:
|
||||
timestamp = int(time.time())
|
||||
file_name = f'tts_client_{timestamp}.pcm'
|
||||
file = open(file_name, 'wb')
|
||||
file_name = f"tts_client_{timestamp}.pcm"
|
||||
file = open(file_name, "wb")
|
||||
while True:
|
||||
meg = await websocket.recv()
|
||||
if isinstance(meg, bytes):
|
||||
@ -306,7 +306,7 @@ async def message(id):
|
||||
text_print_2pass_online = ""
|
||||
text_print = text_print_2pass_offline + "{}".format(text)
|
||||
text_print_2pass_offline += "{}".format(text)
|
||||
#text_print = text_print[-args.words_max_print :]
|
||||
# text_print = text_print[-args.words_max_print :]
|
||||
os.system("clear")
|
||||
print("tts_count len: =>{}".format(tts_count))
|
||||
print("\rpid" + str(id) + ": " + text_print)
|
||||
@ -316,6 +316,8 @@ async def message(id):
|
||||
print("Exception:", e)
|
||||
# traceback.print_exc()
|
||||
# await websocket.close()
|
||||
|
||||
|
||||
async def ws_client(id, chunk_begin, chunk_size):
|
||||
if args.audio_in is None:
|
||||
chunk_begin = 0
|
||||
|
||||
@ -11,8 +11,16 @@ import nls
|
||||
from collections import deque
|
||||
import threading
|
||||
|
||||
|
||||
class NlsTtsSynthesizer:
|
||||
def __init__(self, websocket, tts_fifo, token, appkey, url="wss://nls-gateway-cn-beijing.aliyuncs.com/ws/v1"):
|
||||
def __init__(
|
||||
self,
|
||||
websocket,
|
||||
tts_fifo,
|
||||
token,
|
||||
appkey,
|
||||
url="wss://nls-gateway-cn-beijing.aliyuncs.com/ws/v1",
|
||||
):
|
||||
self.websocket = websocket
|
||||
self.tts_fifo = tts_fifo
|
||||
self.url = url
|
||||
@ -36,30 +44,33 @@ class NlsTtsSynthesizer:
|
||||
on_completed=self.on_completed,
|
||||
on_error=self.on_error,
|
||||
on_close=self.on_close,
|
||||
callback_args=[]
|
||||
callback_args=[],
|
||||
)
|
||||
|
||||
def on_data(self, data, *args):
|
||||
self.count += len(data)
|
||||
print(f"cout: {self.count}")
|
||||
self.tts_fifo.append(data)
|
||||
#with open('tts_server.pcm', 'ab') as file:
|
||||
# with open('tts_server.pcm', 'ab') as file:
|
||||
# file.write(data)
|
||||
|
||||
def on_sentence_begin(self, message, *args):
|
||||
print('on sentence begin =>{}'.format(message))
|
||||
print("on sentence begin =>{}".format(message))
|
||||
|
||||
def on_sentence_synthesis(self, message, *args):
|
||||
print('on sentence synthesis =>{}'.format(message))
|
||||
print("on sentence synthesis =>{}".format(message))
|
||||
|
||||
def on_sentence_end(self, message, *args):
|
||||
print('on sentence end =>{}'.format(message))
|
||||
print("on sentence end =>{}".format(message))
|
||||
|
||||
def on_completed(self, message, *args):
|
||||
print('on completed =>{}'.format(message))
|
||||
print("on completed =>{}".format(message))
|
||||
|
||||
def on_error(self, message, *args):
|
||||
print('on_error args=>{}'.format(args))
|
||||
print("on_error args=>{}".format(args))
|
||||
|
||||
def on_close(self, *args):
|
||||
print('on_close: args=>{}'.format(args))
|
||||
print("on_close: args=>{}".format(args))
|
||||
print("on message data cout: =>{}".format(self.count))
|
||||
self.started = False
|
||||
|
||||
@ -68,16 +79,18 @@ class NlsTtsSynthesizer:
|
||||
self.started = True
|
||||
|
||||
def send_text(self, text):
|
||||
print(f"text: {text}")
|
||||
self.sdk.sendStreamInputTts(text)
|
||||
|
||||
async def stop(self):
|
||||
self.sdk.stopStreamInputTts()
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--host", type=str, default="127.0.0.1", required=False, help="host ip, localhost, 0.0.0.0"
|
||||
"--host", type=str, default="0.0.0.0", required=False, help="host ip, localhost, 0.0.0.0"
|
||||
)
|
||||
parser.add_argument("--port", type=int, default=10095, required=False, help="grpc server port")
|
||||
parser.add_argument("--port", type=int, default=10096, required=False, help="grpc server port")
|
||||
parser.add_argument(
|
||||
"--asr_model",
|
||||
type=str,
|
||||
@ -124,16 +137,6 @@ websocket_users = set()
|
||||
print("model loading")
|
||||
from funasr import AutoModel
|
||||
|
||||
# # asr
|
||||
# model_asr = AutoModel(
|
||||
# model=args.asr_model,
|
||||
# model_revision=args.asr_model_revision,
|
||||
# ngpu=args.ngpu,
|
||||
# ncpu=args.ncpu,
|
||||
# device=args.device,
|
||||
# disable_pbar=False,
|
||||
# disable_log=True,
|
||||
# )
|
||||
|
||||
# vad
|
||||
model_vad = AutoModel(
|
||||
@ -147,26 +150,6 @@ model_vad = AutoModel(
|
||||
# chunk_size=60,
|
||||
)
|
||||
|
||||
# async def async_asr(websocket, audio_in):
|
||||
# if len(audio_in) > 0:
|
||||
# # print(len(audio_in))
|
||||
# print(type(audio_in))
|
||||
# rec_result = model_asr.generate(input=audio_in, **websocket.status_dict_asr)[0]
|
||||
# print("offline_asr, ", rec_result)
|
||||
#
|
||||
#
|
||||
# if len(rec_result["text"]) > 0:
|
||||
# # print("offline", rec_result)
|
||||
# mode = "2pass-offline" if "2pass" in websocket.mode else websocket.mode
|
||||
# message = json.dumps(
|
||||
# {
|
||||
# "mode": mode,
|
||||
# "text": rec_result["text"],
|
||||
# "wav_name": websocket.wav_name,
|
||||
# "is_final": websocket.is_speaking,
|
||||
# }
|
||||
# )
|
||||
# await websocket.send(message)
|
||||
|
||||
import os
|
||||
|
||||
@ -205,20 +188,26 @@ if "key" in os.environ:
|
||||
key = os.environ["key"]
|
||||
api.login(key)
|
||||
|
||||
appkey = "xxx"
|
||||
appkey_token = "xxx"
|
||||
if "appkey" in os.environ:
|
||||
appkey = os.environ["appkey"]
|
||||
appkey_token = os.environ["appkey_token"]
|
||||
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
|
||||
# os.environ["MODELSCOPE_CACHE"] = "/nfs/zhifu.gzf/modelscope"
|
||||
# llm_dir = snapshot_download('qwen/Qwen2-7B-Instruct', cache_dir=None, revision='master')
|
||||
# audio_encoder_dir = snapshot_download('iic/SenseVoice', cache_dir=None, revision='master')
|
||||
os.environ["MODELSCOPE_CACHE"] = "/mnt/workspace"
|
||||
llm_dir = snapshot_download("qwen/Qwen2-7B-Instruct", cache_dir=None, revision="master")
|
||||
audio_encoder_dir = snapshot_download("iic/SenseVoice", cache_dir=None, revision="master")
|
||||
|
||||
llm_dir = "/cpfs_speech/zhifu.gzf/init_model/qwen/Qwen2-7B-Instruct"
|
||||
audio_encoder_dir = "/nfs/zhifu.gzf/init_model/SenseVoiceLargeModelscope"
|
||||
# llm_dir = "/cpfs_speech/zhifu.gzf/init_model/qwen/Qwen2-7B-Instruct"
|
||||
# audio_encoder_dir = "/nfs/zhifu.gzf/init_model/SenseVoiceLargeModelscope"
|
||||
|
||||
device = "cuda:0"
|
||||
|
||||
all_file_paths = [
|
||||
"/nfs/zhifu.gzf/init_model/Speech2Text_Align_V0712_modelscope"
|
||||
# "FunAudioLLM/Speech2Text_Align_V0712",
|
||||
# "/nfs/zhifu.gzf/init_model/Speech2Text_Align_V0712_modelscope"
|
||||
"FunAudioLLM/Speech2Text_Align_V0712",
|
||||
# "FunAudioLLM/Speech2Text_Align_V0718",
|
||||
# "FunAudioLLM/Speech2Text_Align_V0628",
|
||||
]
|
||||
@ -246,6 +235,7 @@ tokenizer = model_llm.kwargs["tokenizer"]
|
||||
|
||||
model_dict = {"model": model, "frontend": frontend, "tokenizer": tokenizer}
|
||||
|
||||
|
||||
async def send_to_client(websocket, syntheszier, tts_fifo):
|
||||
# Sending tts data to the client
|
||||
while True:
|
||||
@ -260,6 +250,8 @@ async def send_to_client(websocket, syntheszier, tts_fifo):
|
||||
else:
|
||||
print("WebSocket connection is not open or syntheszier is not started.")
|
||||
break
|
||||
|
||||
|
||||
async def model_inference(
|
||||
websocket,
|
||||
audio_in,
|
||||
@ -271,7 +263,9 @@ async def model_inference(
|
||||
text_usr="",
|
||||
):
|
||||
fifo_queue = deque()
|
||||
synthesizer = NlsTtsSynthesizer(websocket=websocket, tts_fifo=fifo_queue, token="xxx", appkey="xxx")
|
||||
synthesizer = NlsTtsSynthesizer(
|
||||
websocket=websocket, tts_fifo=fifo_queue, token=appkey_token, appkey=appkey
|
||||
)
|
||||
synthesizer.start()
|
||||
beg0 = time.time()
|
||||
if his_state is None:
|
||||
@ -329,8 +323,10 @@ async def model_inference(
|
||||
f"generated new text: {new_text}, time_fr_receive: {end_llm - beg0:.2f}, time_llm_decode: {end_llm - beg_llm:.2f}"
|
||||
)
|
||||
if len(new_text) > 0:
|
||||
new_text = new_text.replace("<|im_end|>", "")
|
||||
res += new_text
|
||||
synthesizer.send_text(new_text)
|
||||
res += new_text.replace("<|im_end|>", "")
|
||||
|
||||
contents_i[-1]["content"] = res
|
||||
websocket.llm_state["contents_i"] = contents_i
|
||||
# history[-1][1] = res
|
||||
@ -341,7 +337,7 @@ async def model_inference(
|
||||
"mode": mode,
|
||||
"text": new_text,
|
||||
"wav_name": websocket.wav_name,
|
||||
"is_final": websocket.is_speaking,
|
||||
"is_final": False,
|
||||
}
|
||||
)
|
||||
# print(f"online: {message}")
|
||||
@ -350,6 +346,7 @@ async def model_inference(
|
||||
while len(fifo_queue) > 0:
|
||||
await websocket.send(fifo_queue.popleft())
|
||||
|
||||
# synthesizer.send_text(res)
|
||||
tts_to_client_task = asyncio.create_task(send_to_client(websocket, synthesizer, fifo_queue))
|
||||
synthesizer.stop()
|
||||
await tts_to_client_task
|
||||
@ -359,7 +356,7 @@ async def model_inference(
|
||||
"mode": mode,
|
||||
"text": res,
|
||||
"wav_name": websocket.wav_name,
|
||||
"is_final": websocket.is_speaking,
|
||||
"is_final": True,
|
||||
}
|
||||
)
|
||||
# print(f"offline: {message}")
|
||||
@ -460,7 +457,8 @@ async def ws_serve(websocket, path):
|
||||
frames_asr = []
|
||||
frames_asr.extend(frames_pre)
|
||||
# asr punc offline
|
||||
if speech_end_i != -1 or not websocket.is_speaking:
|
||||
# if speech_end_i != -1 or not websocket.is_speaking:
|
||||
if not websocket.is_speaking:
|
||||
# print("vad end point")
|
||||
if websocket.mode == "2pass" or websocket.mode == "offline":
|
||||
audio_in = b"".join(frames_asr)
|
||||
@ -506,7 +504,7 @@ async def async_vad(websocket, audio_in):
|
||||
return speech_start, speech_end
|
||||
|
||||
|
||||
if len(args.certfile) > 0:
|
||||
if False: # len(args.certfile) > 0:
|
||||
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||
|
||||
# Generate with Lets Encrypt, copied to this location, chown to current user and 400 permissions
|
||||
|
||||
@ -3,6 +3,7 @@ import asyncio
|
||||
import json
|
||||
import websockets
|
||||
import time
|
||||
from datetime import datetime
|
||||
import argparse
|
||||
import ssl
|
||||
import numpy as np
|
||||
@ -14,7 +15,7 @@ from modelscope.hub.snapshot_download import snapshot_download
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--host", type=str, default="127.0.0.1", required=False, help="host ip, localhost, 0.0.0.0"
|
||||
"--host", type=str, default="0.0.0.0", required=False, help="host ip, localhost, 0.0.0.0"
|
||||
)
|
||||
parser.add_argument("--port", type=int, default=10095, required=False, help="grpc server port")
|
||||
parser.add_argument(
|
||||
@ -30,14 +31,14 @@ parser.add_argument("--ncpu", type=int, default=4, help="cpu cores")
|
||||
parser.add_argument(
|
||||
"--certfile",
|
||||
type=str,
|
||||
default="../../ssl_key/server.crt",
|
||||
default="",
|
||||
required=False,
|
||||
help="certfile for ssl",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--keyfile",
|
||||
type=str,
|
||||
default="../../ssl_key/server.key",
|
||||
default="ssl_key/server.key",
|
||||
required=False,
|
||||
help="keyfile for ssl",
|
||||
)
|
||||
@ -61,20 +62,25 @@ model_vad = AutoModel(
|
||||
)
|
||||
|
||||
api = HubApi()
|
||||
key = "ed70b703-9ec7-44b8-b5ce-5f4527719810"
|
||||
api.login(key)
|
||||
if "key" in os.environ:
|
||||
key = os.environ["key"]
|
||||
api.login(key)
|
||||
api.login(key)
|
||||
|
||||
# os.environ["MODELSCOPE_CACHE"] = "/nfs/zhifu.gzf/modelscope"
|
||||
# llm_dir = snapshot_download('qwen/Qwen2-7B-Instruct', cache_dir=None, revision='master')
|
||||
# audio_encoder_dir = snapshot_download('iic/SenseVoice', cache_dir=None, revision='master')
|
||||
llm_dir = snapshot_download("qwen/Qwen2-7B-Instruct", cache_dir=None, revision="master")
|
||||
audio_encoder_dir = snapshot_download("iic/SenseVoice", cache_dir=None, revision="master")
|
||||
|
||||
llm_dir = "/cpfs_speech/zhifu.gzf/init_model/qwen/Qwen2-7B-Instruct"
|
||||
audio_encoder_dir = "/nfs/yangyexin.yyx/init_model/iic/SenseVoiceModelscope_0712"
|
||||
# llm_dir = "/cpfs_speech/zhifu.gzf/init_model/qwen/Qwen2-7B-Instruct"
|
||||
# audio_encoder_dir = "/nfs/yangyexin.yyx/init_model/iic/SenseVoiceModelscope_0712"
|
||||
device = "cuda:0"
|
||||
all_file_paths = [
|
||||
"/nfs/yangyexin.yyx/init_model/qwen2_7b_mmt_v14_20240830/",
|
||||
"/nfs/yangyexin.yyx/init_model/audiolm_v14_20240824_train_encoder_all_20240822_lr1e-4_warmup2350/"
|
||||
"FunAudioLLM/qwen2_7b_mmt_v14_20240830",
|
||||
"FunAudioLLM/audiolm_v11_20240807",
|
||||
"FunAudioLLM/Speech2Text_Align_V0712",
|
||||
"FunAudioLLM/Speech2Text_Align_V0718",
|
||||
"FunAudioLLM/Speech2Text_Align_V0628",
|
||||
]
|
||||
|
||||
llm_kwargs = {"num_beams": 1, "do_sample": False}
|
||||
@ -85,6 +91,7 @@ MAX_ITER_PER_CHUNK = 20
|
||||
|
||||
ckpt_dir = all_file_paths[0]
|
||||
|
||||
|
||||
def contains_lora_folder(directory):
|
||||
for name in os.listdir(directory):
|
||||
full_path = os.path.join(directory, name)
|
||||
@ -92,7 +99,10 @@ def contains_lora_folder(directory):
|
||||
return full_path
|
||||
return None
|
||||
|
||||
|
||||
ckpt_dir = snapshot_download(ckpt_dir, cache_dir=None, revision="master")
|
||||
lora_folder = contains_lora_folder(ckpt_dir)
|
||||
|
||||
if lora_folder is not None:
|
||||
model_llm = AutoModel(
|
||||
model=ckpt_dir,
|
||||
@ -102,7 +112,11 @@ if lora_folder is not None:
|
||||
llm_dtype="bf16",
|
||||
max_length=1024,
|
||||
llm_kwargs=llm_kwargs,
|
||||
llm_conf={"init_param_path": llm_dir, "lora_conf": {"init_param_path": lora_folder}, "load_kwargs": {"attn_implementation": "eager"}},
|
||||
llm_conf={
|
||||
"init_param_path": llm_dir,
|
||||
"lora_conf": {"init_param_path": lora_folder},
|
||||
"load_kwargs": {"attn_implementation": "eager"},
|
||||
},
|
||||
tokenizer_conf={"init_param_path": llm_dir},
|
||||
audio_encoder=audio_encoder_dir,
|
||||
)
|
||||
@ -128,6 +142,7 @@ model_dict = {"model": model, "frontend": frontend, "tokenizer": tokenizer}
|
||||
|
||||
print("model loaded! only support one client at the same time now!!!!")
|
||||
|
||||
|
||||
def load_bytes(input):
|
||||
middle_data = np.frombuffer(input, dtype=np.int16)
|
||||
middle_data = np.asarray(middle_data)
|
||||
@ -143,7 +158,12 @@ def load_bytes(input):
|
||||
array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
|
||||
return array
|
||||
|
||||
async def streaming_transcribe(websocket, audio_in, his_state=None, asr_prompt=None, s2tt_prompt=None):
|
||||
|
||||
async def streaming_transcribe(
|
||||
websocket, audio_in, his_state=None, asr_prompt=None, s2tt_prompt=None
|
||||
):
|
||||
current_time = datetime.now()
|
||||
print("DEBUG:" + str(current_time) + " call streaming_transcribe function:")
|
||||
if his_state is None:
|
||||
his_state = model_dict
|
||||
model = his_state["model"]
|
||||
@ -157,17 +177,21 @@ async def streaming_transcribe(websocket, audio_in, his_state=None, asr_prompt=N
|
||||
else:
|
||||
previous_asr_text = websocket.streaming_state.get("previous_asr_text", "")
|
||||
previous_s2tt_text = websocket.streaming_state.get("previous_s2tt_text", "")
|
||||
previous_vad_onscreen_asr_text = websocket.streaming_state.get("previous_vad_onscreen_asr_text", "")
|
||||
previous_vad_onscreen_s2tt_text = websocket.streaming_state.get("previous_vad_onscreen_s2tt_text", "")
|
||||
|
||||
previous_vad_onscreen_asr_text = websocket.streaming_state.get(
|
||||
"previous_vad_onscreen_asr_text", ""
|
||||
)
|
||||
previous_vad_onscreen_s2tt_text = websocket.streaming_state.get(
|
||||
"previous_vad_onscreen_s2tt_text", ""
|
||||
)
|
||||
|
||||
if asr_prompt is None or asr_prompt == "":
|
||||
asr_prompt = "Copy:"
|
||||
asr_prompt = "Speech transcription:"
|
||||
if s2tt_prompt is None or s2tt_prompt == "":
|
||||
s2tt_prompt = "Translate the following sentence into English:"
|
||||
s2tt_prompt = "Translate into English:"
|
||||
|
||||
audio_seconds = load_bytes(audio_in).shape[0] / 16000
|
||||
print(f"Streaming audio length: {audio_seconds} seconds")
|
||||
|
||||
|
||||
asr_content = []
|
||||
system_prompt = "You are a helpful assistant."
|
||||
asr_content.append({"role": "system", "content": system_prompt})
|
||||
@ -185,12 +209,24 @@ async def streaming_transcribe(websocket, audio_in, his_state=None, asr_prompt=N
|
||||
|
||||
streaming_time_beg = time.time()
|
||||
inputs_asr_embeds, contents, batch, source_ids, meta_data = model.inference_prepare(
|
||||
[asr_content], None, "test_demo", tokenizer, frontend, device=device, infer_with_assistant_input=True
|
||||
[asr_content],
|
||||
None,
|
||||
"test_demo",
|
||||
tokenizer,
|
||||
frontend,
|
||||
device=device,
|
||||
infer_with_assistant_input=True,
|
||||
)
|
||||
model_asr_inputs = {}
|
||||
model_asr_inputs["inputs_embeds"] = inputs_asr_embeds
|
||||
inputs_s2tt_embeds, contents, batch, source_ids, meta_data = model.inference_prepare(
|
||||
[s2tt_content], None, "test_demo", tokenizer, frontend, device=device, infer_with_assistant_input=True
|
||||
[s2tt_content],
|
||||
None,
|
||||
"test_demo",
|
||||
tokenizer,
|
||||
frontend,
|
||||
device=device,
|
||||
infer_with_assistant_input=True,
|
||||
)
|
||||
model_s2tt_inputs = {}
|
||||
model_s2tt_inputs["inputs_embeds"] = inputs_s2tt_embeds
|
||||
@ -208,7 +244,7 @@ async def streaming_transcribe(websocket, audio_in, his_state=None, asr_prompt=N
|
||||
s2tt_generation_kwargs.update(llm_kwargs)
|
||||
s2tt_thread = Thread(target=model.llm.generate, kwargs=s2tt_generation_kwargs)
|
||||
s2tt_thread.start()
|
||||
|
||||
|
||||
onscreen_asr_res = previous_asr_text
|
||||
onscreen_s2tt_res = previous_s2tt_text
|
||||
|
||||
@ -220,8 +256,8 @@ async def streaming_transcribe(websocket, audio_in, his_state=None, asr_prompt=N
|
||||
is_s2tt_repetition = False
|
||||
|
||||
for new_asr_text in asr_streamer:
|
||||
print(f"generated new asr text: {new_asr_text}")
|
||||
|
||||
current_time = datetime.now()
|
||||
print("DEBUG: " + str(current_time) + " " + f"generated new asr text: {new_asr_text}")
|
||||
if len(new_asr_text) > 0:
|
||||
onscreen_asr_res += new_asr_text.replace("<|im_end|>", "")
|
||||
if len(new_asr_text.replace("<|im_end|>", "")) > 0:
|
||||
@ -233,7 +269,13 @@ async def streaming_transcribe(websocket, audio_in, his_state=None, asr_prompt=N
|
||||
if remain_s2tt_text:
|
||||
try:
|
||||
new_s2tt_text = next(s2tt_streamer)
|
||||
print(f"generated new s2tt text: {new_s2tt_text}")
|
||||
current_time = datetime.now()
|
||||
print(
|
||||
"DEBUG: "
|
||||
+ str(current_time)
|
||||
+ " "
|
||||
+ f"generated new s2tt text: {new_s2tt_text}"
|
||||
)
|
||||
s2tt_iter_cnt += 1
|
||||
if len(new_s2tt_text) > 0:
|
||||
onscreen_s2tt_res += new_s2tt_text.replace("<|im_end|>", "")
|
||||
@ -241,16 +283,16 @@ async def streaming_transcribe(websocket, audio_in, his_state=None, asr_prompt=N
|
||||
new_s2tt_text = ""
|
||||
remain_s2tt_text = False
|
||||
pass
|
||||
|
||||
|
||||
if len(new_asr_text) > 0 or len(new_s2tt_text) > 0:
|
||||
all_asr_res = previous_vad_onscreen_asr_text + onscreen_asr_res
|
||||
fix_asr_part = previous_vad_onscreen_asr_text + previous_asr_text
|
||||
unfix_asr_part = all_asr_res[len(fix_asr_part):]
|
||||
return_asr_res = fix_asr_part + "<em>"+ unfix_asr_part + "</em>"
|
||||
unfix_asr_part = all_asr_res[len(fix_asr_part) :]
|
||||
return_asr_res = fix_asr_part + "<em>" + unfix_asr_part + "</em>"
|
||||
all_s2tt_res = previous_vad_onscreen_s2tt_text + onscreen_s2tt_res
|
||||
fix_s2tt_part = previous_vad_onscreen_s2tt_text + previous_s2tt_text
|
||||
unfix_s2tt_part = all_s2tt_res[len(fix_s2tt_part):]
|
||||
return_s2tt_res = fix_s2tt_part + "<em>"+ unfix_s2tt_part + "</em>"
|
||||
unfix_s2tt_part = all_s2tt_res[len(fix_s2tt_part) :]
|
||||
return_s2tt_res = fix_s2tt_part + "<em>" + unfix_s2tt_part + "</em>"
|
||||
message = json.dumps(
|
||||
{
|
||||
"mode": "online",
|
||||
@ -261,14 +303,19 @@ async def streaming_transcribe(websocket, audio_in, his_state=None, asr_prompt=N
|
||||
}
|
||||
)
|
||||
await websocket.send(message)
|
||||
websocket.streaming_state["onscreen_asr_res"] = previous_vad_onscreen_asr_text + onscreen_asr_res
|
||||
websocket.streaming_state["onscreen_s2tt_res"] = previous_vad_onscreen_s2tt_text + onscreen_s2tt_res
|
||||
websocket.streaming_state["onscreen_asr_res"] = (
|
||||
previous_vad_onscreen_asr_text + onscreen_asr_res
|
||||
)
|
||||
websocket.streaming_state["onscreen_s2tt_res"] = (
|
||||
previous_vad_onscreen_s2tt_text + onscreen_s2tt_res
|
||||
)
|
||||
|
||||
|
||||
if remain_s2tt_text:
|
||||
for new_s2tt_text in s2tt_streamer:
|
||||
print(f"generated new s2tt text: {new_s2tt_text}")
|
||||
|
||||
current_time = datetime.now()
|
||||
print(
|
||||
"DEBUG: " + str(current_time) + " " + f"generated new s2tt text: {new_s2tt_text}"
|
||||
)
|
||||
if len(new_s2tt_text) > 0:
|
||||
onscreen_s2tt_res += new_s2tt_text.replace("<|im_end|>", "")
|
||||
if len(new_s2tt_text.replace("<|im_end|>", "")) > 0:
|
||||
@ -276,16 +323,16 @@ async def streaming_transcribe(websocket, audio_in, his_state=None, asr_prompt=N
|
||||
if s2tt_iter_cnt > MAX_ITER_PER_CHUNK:
|
||||
is_s2tt_repetition = True
|
||||
break
|
||||
|
||||
|
||||
if len(new_s2tt_text) > 0:
|
||||
all_asr_res = previous_vad_onscreen_asr_text + onscreen_asr_res
|
||||
fix_asr_part = previous_vad_onscreen_asr_text + previous_asr_text
|
||||
unfix_asr_part = all_asr_res[len(fix_asr_part):]
|
||||
return_asr_res = fix_asr_part + "<em>"+ unfix_asr_part + "</em>"
|
||||
unfix_asr_part = all_asr_res[len(fix_asr_part) :]
|
||||
return_asr_res = fix_asr_part + "<em>" + unfix_asr_part + "</em>"
|
||||
all_s2tt_res = previous_vad_onscreen_s2tt_text + onscreen_s2tt_res
|
||||
fix_s2tt_part = previous_vad_onscreen_s2tt_text + previous_s2tt_text
|
||||
unfix_s2tt_part = all_s2tt_res[len(fix_s2tt_part):]
|
||||
return_s2tt_res = fix_s2tt_part + "<em>"+ unfix_s2tt_part + "</em>"
|
||||
unfix_s2tt_part = all_s2tt_res[len(fix_s2tt_part) :]
|
||||
return_s2tt_res = fix_s2tt_part + "<em>" + unfix_s2tt_part + "</em>"
|
||||
message = json.dumps(
|
||||
{
|
||||
"mode": "online",
|
||||
@ -296,8 +343,12 @@ async def streaming_transcribe(websocket, audio_in, his_state=None, asr_prompt=N
|
||||
}
|
||||
)
|
||||
await websocket.send(message)
|
||||
websocket.streaming_state["onscreen_asr_res"] = previous_vad_onscreen_asr_text + onscreen_asr_res
|
||||
websocket.streaming_state["onscreen_s2tt_res"] = previous_vad_onscreen_s2tt_text + onscreen_s2tt_res
|
||||
websocket.streaming_state["onscreen_asr_res"] = (
|
||||
previous_vad_onscreen_asr_text + onscreen_asr_res
|
||||
)
|
||||
websocket.streaming_state["onscreen_s2tt_res"] = (
|
||||
previous_vad_onscreen_s2tt_text + onscreen_s2tt_res
|
||||
)
|
||||
|
||||
streaming_time_end = time.time()
|
||||
print(f"Streaming inference time: {streaming_time_end - streaming_time_beg}")
|
||||
@ -307,16 +358,24 @@ async def streaming_transcribe(websocket, audio_in, his_state=None, asr_prompt=N
|
||||
|
||||
if asr_text_len > UNFIX_LEN and audio_seconds > MIN_LEN_SEC_AUDIO_FIX and not is_asr_repetition:
|
||||
pre_previous_asr_text = previous_asr_text
|
||||
previous_asr_text = tokenizer.decode(tokenizer.encode(onscreen_asr_res)[:-UNFIX_LEN]).replace("<EFBFBD>", "")
|
||||
previous_asr_text = tokenizer.decode(
|
||||
tokenizer.encode(onscreen_asr_res)[:-UNFIX_LEN]
|
||||
).replace("<EFBFBD>", "")
|
||||
if len(previous_asr_text) <= len(pre_previous_asr_text):
|
||||
previous_asr_text = pre_previous_asr_text
|
||||
elif is_asr_repetition:
|
||||
pass
|
||||
else:
|
||||
previous_asr_text = ""
|
||||
if s2tt_text_len > UNFIX_LEN and audio_seconds > MIN_LEN_SEC_AUDIO_FIX and not is_s2tt_repetition:
|
||||
if (
|
||||
s2tt_text_len > UNFIX_LEN
|
||||
and audio_seconds > MIN_LEN_SEC_AUDIO_FIX
|
||||
and not is_s2tt_repetition
|
||||
):
|
||||
pre_previous_s2tt_text = previous_s2tt_text
|
||||
previous_s2tt_text = tokenizer.decode(tokenizer.encode(onscreen_s2tt_res)[:-UNFIX_LEN]).replace("<EFBFBD>", "")
|
||||
previous_s2tt_text = tokenizer.decode(
|
||||
tokenizer.encode(onscreen_s2tt_res)[:-UNFIX_LEN]
|
||||
).replace("<EFBFBD>", "")
|
||||
if len(previous_s2tt_text) <= len(pre_previous_s2tt_text):
|
||||
previous_s2tt_text = pre_previous_s2tt_text
|
||||
elif is_s2tt_repetition:
|
||||
@ -325,9 +384,13 @@ async def streaming_transcribe(websocket, audio_in, his_state=None, asr_prompt=N
|
||||
previous_s2tt_text = ""
|
||||
|
||||
websocket.streaming_state["previous_asr_text"] = previous_asr_text
|
||||
websocket.streaming_state["onscreen_asr_res"] = previous_vad_onscreen_asr_text + onscreen_asr_res
|
||||
websocket.streaming_state["onscreen_asr_res"] = (
|
||||
previous_vad_onscreen_asr_text + onscreen_asr_res
|
||||
)
|
||||
websocket.streaming_state["previous_s2tt_text"] = previous_s2tt_text
|
||||
websocket.streaming_state["onscreen_s2tt_res"] = previous_vad_onscreen_s2tt_text + onscreen_s2tt_res
|
||||
websocket.streaming_state["onscreen_s2tt_res"] = (
|
||||
previous_vad_onscreen_s2tt_text + onscreen_s2tt_res
|
||||
)
|
||||
|
||||
print("fix asr part:", previous_asr_text)
|
||||
print("fix s2tt part:", previous_s2tt_text)
|
||||
@ -383,6 +446,13 @@ async def ws_serve(websocket, path):
|
||||
|
||||
try:
|
||||
async for message in websocket:
|
||||
if isinstance(message, str):
|
||||
current_time = datetime.now()
|
||||
print("DEBUG:" + str(current_time) + " received message:", message)
|
||||
else:
|
||||
current_time = datetime.now()
|
||||
print("DEBUG:" + str(current_time) + " received audio bytes:")
|
||||
|
||||
if isinstance(message, str):
|
||||
messagejson = json.loads(message)
|
||||
|
||||
@ -403,11 +473,11 @@ async def ws_serve(websocket, path):
|
||||
if "asr_prompt" in messagejson:
|
||||
asr_prompt = messagejson["asr_prompt"]
|
||||
else:
|
||||
asr_prompt = "Copy:"
|
||||
asr_prompt = "Speech transcription:"
|
||||
if "s2tt_prompt" in messagejson:
|
||||
s2tt_prompt = messagejson["s2tt_prompt"]
|
||||
else:
|
||||
s2tt_prompt = "Translate the following sentence into English:"
|
||||
s2tt_prompt = "Translate into English:"
|
||||
|
||||
websocket.status_dict_vad["chunk_size"] = int(
|
||||
chunk_size[1] * 60 / websocket.chunk_interval
|
||||
@ -421,19 +491,20 @@ async def ws_serve(websocket, path):
|
||||
# asr online
|
||||
websocket.streaming_state["is_final"] = speech_end_i != -1
|
||||
if (
|
||||
(len(frames_asr) % websocket.chunk_interval == 0
|
||||
or websocket.streaming_state["is_final"])
|
||||
and len(frames_asr) != 0
|
||||
):
|
||||
len(frames_asr) % websocket.chunk_interval == 0
|
||||
or websocket.streaming_state["is_final"]
|
||||
) and len(frames_asr) != 0:
|
||||
audio_in = b"".join(frames_asr)
|
||||
try:
|
||||
await streaming_transcribe(websocket, audio_in, asr_prompt=asr_prompt, s2tt_prompt=s2tt_prompt)
|
||||
await streaming_transcribe(
|
||||
websocket, audio_in, asr_prompt=asr_prompt, s2tt_prompt=s2tt_prompt
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"error in streaming, {e}")
|
||||
print(f"error in streaming, {websocket.streaming_state}")
|
||||
if speech_start:
|
||||
frames_asr.append(message)
|
||||
|
||||
|
||||
# vad online
|
||||
try:
|
||||
speech_start_i, speech_end_i = await async_vad(websocket, message)
|
||||
@ -450,7 +521,9 @@ async def ws_serve(websocket, path):
|
||||
if speech_end_i != -1 or not websocket.is_speaking:
|
||||
audio_in = b"".join(frames_asr)
|
||||
try:
|
||||
await streaming_transcribe(websocket, audio_in, asr_prompt=asr_prompt, s2tt_prompt=s2tt_prompt)
|
||||
await streaming_transcribe(
|
||||
websocket, audio_in, asr_prompt=asr_prompt, s2tt_prompt=s2tt_prompt
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"error in streaming, {e}")
|
||||
print(f"error in streaming, {websocket.streaming_state}")
|
||||
@ -460,16 +533,37 @@ async def ws_serve(websocket, path):
|
||||
websocket.streaming_state["previous_s2tt_text"] = ""
|
||||
now_onscreen_asr_res = websocket.streaming_state.get("onscreen_asr_res", "")
|
||||
now_onscreen_s2tt_res = websocket.streaming_state.get("onscreen_s2tt_res", "")
|
||||
if len(tokenizer.encode(now_onscreen_asr_res.split("\n\n")[-1])) < MIN_LEN_PER_PARAGRAPH or len(tokenizer.encode(now_onscreen_s2tt_res.split("\n\n")[-1])) < MIN_LEN_PER_PARAGRAPH:
|
||||
if now_onscreen_asr_res.endswith(".") or now_onscreen_asr_res.endswith("?") or now_onscreen_asr_res.endswith("!"):
|
||||
if (
|
||||
len(tokenizer.encode(now_onscreen_asr_res.split("\n\n")[-1]))
|
||||
< MIN_LEN_PER_PARAGRAPH
|
||||
or len(tokenizer.encode(now_onscreen_s2tt_res.split("\n\n")[-1]))
|
||||
< MIN_LEN_PER_PARAGRAPH
|
||||
):
|
||||
if (
|
||||
now_onscreen_asr_res.endswith(".")
|
||||
or now_onscreen_asr_res.endswith("?")
|
||||
or now_onscreen_asr_res.endswith("!")
|
||||
):
|
||||
now_onscreen_asr_res += " "
|
||||
if now_onscreen_s2tt_res.endswith(".") or now_onscreen_s2tt_res.endswith("?") or now_onscreen_s2tt_res.endswith("!"):
|
||||
if (
|
||||
now_onscreen_s2tt_res.endswith(".")
|
||||
or now_onscreen_s2tt_res.endswith("?")
|
||||
or now_onscreen_s2tt_res.endswith("!")
|
||||
):
|
||||
now_onscreen_s2tt_res += " "
|
||||
websocket.streaming_state["previous_vad_onscreen_asr_text"] = now_onscreen_asr_res
|
||||
websocket.streaming_state["previous_vad_onscreen_s2tt_text"] = now_onscreen_s2tt_res
|
||||
websocket.streaming_state["previous_vad_onscreen_asr_text"] = (
|
||||
now_onscreen_asr_res
|
||||
)
|
||||
websocket.streaming_state["previous_vad_onscreen_s2tt_text"] = (
|
||||
now_onscreen_s2tt_res
|
||||
)
|
||||
else:
|
||||
websocket.streaming_state["previous_vad_onscreen_asr_text"] = now_onscreen_asr_res + "\n\n"
|
||||
websocket.streaming_state["previous_vad_onscreen_s2tt_text"] = now_onscreen_s2tt_res + "\n\n"
|
||||
websocket.streaming_state["previous_vad_onscreen_asr_text"] = (
|
||||
now_onscreen_asr_res + "\n\n"
|
||||
)
|
||||
websocket.streaming_state["previous_vad_onscreen_s2tt_text"] = (
|
||||
now_onscreen_s2tt_res + "\n\n"
|
||||
)
|
||||
if not websocket.is_speaking:
|
||||
websocket.vad_pre_idx = 0
|
||||
frames = []
|
||||
@ -491,6 +585,8 @@ async def ws_serve(websocket, path):
|
||||
|
||||
|
||||
async def async_vad(websocket, audio_in):
|
||||
current_time = datetime.now()
|
||||
print("DEBUG:" + str(current_time) + " call vad function:")
|
||||
segments_result = model_vad.generate(input=audio_in, **websocket.status_dict_vad)[0]["value"]
|
||||
# print(segments_result)
|
||||
|
||||
@ -506,7 +602,7 @@ async def async_vad(websocket, audio_in):
|
||||
return speech_start, speech_end
|
||||
|
||||
|
||||
if len(args.certfile) > 0:
|
||||
if False:
|
||||
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||
|
||||
# Generate with Lets Encrypt, copied to this location, chown to current user and 400 permissions
|
||||
|
||||
Loading…
Reference in New Issue
Block a user