Merge branch llm_dev_gzf into dev_gzf_deepspeed

Title: Add LoRA model support (loading, training checkpoint saving, and server-side inference)

本次代码评审主要涉及对一个基于WebSocket的语音识别与处理系统的更新,包括添加日志打印、修改默认参数、增加对LoRA模型的支持、调整错误处理逻辑、优化音频处理和文本显示逻辑,以及代码结构和注释的若干改进,旨在提升系统稳定性和用户体验。
Link: https://code.alibaba-inc.com/zhifu.gzf/FunASR/codereview/18202582
This commit is contained in:
yangyexin.yyx 2024-09-03 13:55:59 +08:00
commit 6aa4f30555
5 changed files with 234 additions and 124 deletions

View File

@ -972,11 +972,17 @@ class LLMASR4(nn.Module):
lora_init_param_path = lora_conf.get("init_param_path", None)
if lora_init_param_path is not None:
logging.info(f"lora_init_param_path: {lora_init_param_path}")
model = PeftModel.from_pretrained(model, lora_init_param_path)
for name, param in model.named_parameters():
if not lora_conf.get("freeze_lora", False):
if "lora_" in name:
param.requires_grad = True
else:
peft_config = LoraConfig(**lora_conf)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
model.print_trainable_parameters()
if llm_conf.get("activation_checkpoint", False):
model.gradient_checkpointing_enable()

View File

@ -225,6 +225,14 @@ class Trainer:
ckpt_name = f"ds-model.pt.ep{epoch}.{step}"
filename = os.path.join(self.output_dir, ckpt_name)
if self.use_lora:
lora_outdir = f"{self.output_dir}/lora-{ckpt_name}"
os.makedirs(lora_outdir, exist_ok=True)
if hasattr(model, "module"):
model.module.llm.save_pretrained(lora_outdir)
else:
model.llm.save_pretrained(lora_outdir)
# torch.save(state, filename)
with torch.no_grad():
model.save_checkpoint(save_dir=self.output_dir, tag=ckpt_name, client_state=state)

View File

@ -20,7 +20,7 @@ parser = argparse.ArgumentParser()
parser.add_argument(
"--host", type=str, default="localhost", required=False, help="host ip, localhost, 0.0.0.0"
)
parser.add_argument("--port", type=int, default=10095, required=False, help="grpc server port")
parser.add_argument("--port", type=int, default=10096, required=False, help="grpc server port")
parser.add_argument("--chunk_size", type=str, default="5, 10, 5", help="chunk")
parser.add_argument("--encoder_chunk_look_back", type=int, default=4, help="chunk")
parser.add_argument("--decoder_chunk_look_back", type=int, default=0, help="chunk")
@ -42,7 +42,7 @@ parser.add_argument(
parser.add_argument("--thread_num", type=int, default=1, help="thread_num")
parser.add_argument("--words_max_print", type=int, default=10000, help="chunk")
parser.add_argument("--output_dir", type=str, default=None, help="output_dir")
parser.add_argument("--ssl", type=int, default=1, help="1 for ssl connect, 0 for no ssl")
parser.add_argument("--ssl", type=int, default=0, help="1 for ssl connect, 0 for no ssl")
parser.add_argument("--use_itn", type=int, default=1, help="1 for using itn, 0 for not itn")
parser.add_argument("--mode", type=str, default="2pass", help="offline, online, 2pass")
@ -238,7 +238,7 @@ async def record_from_scp(chunk_begin, chunk_size):
while not offline_msg_done:
await asyncio.sleep(1)
#await websocket.close()
# await websocket.close()
async def message(id):
@ -255,8 +255,8 @@ async def message(id):
ibest_writer = None
try:
timestamp = int(time.time())
file_name = f'tts_client_{timestamp}.pcm'
file = open(file_name, 'wb')
file_name = f"tts_client_{timestamp}.pcm"
file = open(file_name, "wb")
while True:
meg = await websocket.recv()
if isinstance(meg, bytes):
@ -306,7 +306,7 @@ async def message(id):
text_print_2pass_online = ""
text_print = text_print_2pass_offline + "{}".format(text)
text_print_2pass_offline += "{}".format(text)
#text_print = text_print[-args.words_max_print :]
# text_print = text_print[-args.words_max_print :]
os.system("clear")
print("tts_count len: =>{}".format(tts_count))
print("\rpid" + str(id) + ": " + text_print)
@ -316,6 +316,8 @@ async def message(id):
print("Exception:", e)
# traceback.print_exc()
# await websocket.close()
async def ws_client(id, chunk_begin, chunk_size):
if args.audio_in is None:
chunk_begin = 0

View File

@ -11,8 +11,16 @@ import nls
from collections import deque
import threading
class NlsTtsSynthesizer:
def __init__(self, websocket, tts_fifo, token, appkey, url="wss://nls-gateway-cn-beijing.aliyuncs.com/ws/v1"):
def __init__(
self,
websocket,
tts_fifo,
token,
appkey,
url="wss://nls-gateway-cn-beijing.aliyuncs.com/ws/v1",
):
self.websocket = websocket
self.tts_fifo = tts_fifo
self.url = url
@ -36,30 +44,33 @@ class NlsTtsSynthesizer:
on_completed=self.on_completed,
on_error=self.on_error,
on_close=self.on_close,
callback_args=[]
callback_args=[],
)
def on_data(self, data, *args):
self.count += len(data)
print(f"cout: {self.count}")
self.tts_fifo.append(data)
#with open('tts_server.pcm', 'ab') as file:
# with open('tts_server.pcm', 'ab') as file:
# file.write(data)
def on_sentence_begin(self, message, *args):
print('on sentence begin =>{}'.format(message))
print("on sentence begin =>{}".format(message))
def on_sentence_synthesis(self, message, *args):
print('on sentence synthesis =>{}'.format(message))
print("on sentence synthesis =>{}".format(message))
def on_sentence_end(self, message, *args):
print('on sentence end =>{}'.format(message))
print("on sentence end =>{}".format(message))
def on_completed(self, message, *args):
print('on completed =>{}'.format(message))
print("on completed =>{}".format(message))
def on_error(self, message, *args):
print('on_error args=>{}'.format(args))
print("on_error args=>{}".format(args))
def on_close(self, *args):
print('on_close: args=>{}'.format(args))
print("on_close: args=>{}".format(args))
print("on message data cout: =>{}".format(self.count))
self.started = False
@ -68,16 +79,18 @@ class NlsTtsSynthesizer:
self.started = True
def send_text(self, text):
print(f"text: {text}")
self.sdk.sendStreamInputTts(text)
async def stop(self):
self.sdk.stopStreamInputTts()
parser = argparse.ArgumentParser()
parser.add_argument(
"--host", type=str, default="127.0.0.1", required=False, help="host ip, localhost, 0.0.0.0"
"--host", type=str, default="0.0.0.0", required=False, help="host ip, localhost, 0.0.0.0"
)
parser.add_argument("--port", type=int, default=10095, required=False, help="grpc server port")
parser.add_argument("--port", type=int, default=10096, required=False, help="grpc server port")
parser.add_argument(
"--asr_model",
type=str,
@ -124,16 +137,6 @@ websocket_users = set()
print("model loading")
from funasr import AutoModel
# # asr
# model_asr = AutoModel(
# model=args.asr_model,
# model_revision=args.asr_model_revision,
# ngpu=args.ngpu,
# ncpu=args.ncpu,
# device=args.device,
# disable_pbar=False,
# disable_log=True,
# )
# vad
model_vad = AutoModel(
@ -147,26 +150,6 @@ model_vad = AutoModel(
# chunk_size=60,
)
# async def async_asr(websocket, audio_in):
# if len(audio_in) > 0:
# # print(len(audio_in))
# print(type(audio_in))
# rec_result = model_asr.generate(input=audio_in, **websocket.status_dict_asr)[0]
# print("offline_asr, ", rec_result)
#
#
# if len(rec_result["text"]) > 0:
# # print("offline", rec_result)
# mode = "2pass-offline" if "2pass" in websocket.mode else websocket.mode
# message = json.dumps(
# {
# "mode": mode,
# "text": rec_result["text"],
# "wav_name": websocket.wav_name,
# "is_final": websocket.is_speaking,
# }
# )
# await websocket.send(message)
import os
@ -205,20 +188,26 @@ if "key" in os.environ:
key = os.environ["key"]
api.login(key)
appkey = "xxx"
appkey_token = "xxx"
if "appkey" in os.environ:
appkey = os.environ["appkey"]
appkey_token = os.environ["appkey_token"]
from modelscope.hub.snapshot_download import snapshot_download
# os.environ["MODELSCOPE_CACHE"] = "/nfs/zhifu.gzf/modelscope"
# llm_dir = snapshot_download('qwen/Qwen2-7B-Instruct', cache_dir=None, revision='master')
# audio_encoder_dir = snapshot_download('iic/SenseVoice', cache_dir=None, revision='master')
os.environ["MODELSCOPE_CACHE"] = "/mnt/workspace"
llm_dir = snapshot_download("qwen/Qwen2-7B-Instruct", cache_dir=None, revision="master")
audio_encoder_dir = snapshot_download("iic/SenseVoice", cache_dir=None, revision="master")
llm_dir = "/cpfs_speech/zhifu.gzf/init_model/qwen/Qwen2-7B-Instruct"
audio_encoder_dir = "/nfs/zhifu.gzf/init_model/SenseVoiceLargeModelscope"
# llm_dir = "/cpfs_speech/zhifu.gzf/init_model/qwen/Qwen2-7B-Instruct"
# audio_encoder_dir = "/nfs/zhifu.gzf/init_model/SenseVoiceLargeModelscope"
device = "cuda:0"
all_file_paths = [
"/nfs/zhifu.gzf/init_model/Speech2Text_Align_V0712_modelscope"
# "FunAudioLLM/Speech2Text_Align_V0712",
# "/nfs/zhifu.gzf/init_model/Speech2Text_Align_V0712_modelscope"
"FunAudioLLM/Speech2Text_Align_V0712",
# "FunAudioLLM/Speech2Text_Align_V0718",
# "FunAudioLLM/Speech2Text_Align_V0628",
]
@ -246,6 +235,7 @@ tokenizer = model_llm.kwargs["tokenizer"]
model_dict = {"model": model, "frontend": frontend, "tokenizer": tokenizer}
async def send_to_client(websocket, syntheszier, tts_fifo):
# Sending tts data to the client
while True:
@ -260,6 +250,8 @@ async def send_to_client(websocket, syntheszier, tts_fifo):
else:
print("WebSocket connection is not open or syntheszier is not started.")
break
async def model_inference(
websocket,
audio_in,
@ -271,7 +263,9 @@ async def model_inference(
text_usr="",
):
fifo_queue = deque()
synthesizer = NlsTtsSynthesizer(websocket=websocket, tts_fifo=fifo_queue, token="xxx", appkey="xxx")
synthesizer = NlsTtsSynthesizer(
websocket=websocket, tts_fifo=fifo_queue, token=appkey_token, appkey=appkey
)
synthesizer.start()
beg0 = time.time()
if his_state is None:
@ -329,8 +323,10 @@ async def model_inference(
f"generated new text {new_text}, time_fr_receive: {end_llm - beg0:.2f}, time_llm_decode: {end_llm - beg_llm:.2f}"
)
if len(new_text) > 0:
new_text = new_text.replace("<|im_end|>", "")
res += new_text
synthesizer.send_text(new_text)
res += new_text.replace("<|im_end|>", "")
contents_i[-1]["content"] = res
websocket.llm_state["contents_i"] = contents_i
# history[-1][1] = res
@ -341,7 +337,7 @@ async def model_inference(
"mode": mode,
"text": new_text,
"wav_name": websocket.wav_name,
"is_final": websocket.is_speaking,
"is_final": False,
}
)
# print(f"online: {message}")
@ -350,6 +346,7 @@ async def model_inference(
while len(fifo_queue) > 0:
await websocket.send(fifo_queue.popleft())
# synthesizer.send_text(res)
tts_to_client_task = asyncio.create_task(send_to_client(websocket, synthesizer, fifo_queue))
synthesizer.stop()
await tts_to_client_task
@ -359,7 +356,7 @@ async def model_inference(
"mode": mode,
"text": res,
"wav_name": websocket.wav_name,
"is_final": websocket.is_speaking,
"is_final": True,
}
)
# print(f"offline: {message}")
@ -460,7 +457,8 @@ async def ws_serve(websocket, path):
frames_asr = []
frames_asr.extend(frames_pre)
# asr punc offline
if speech_end_i != -1 or not websocket.is_speaking:
# if speech_end_i != -1 or not websocket.is_speaking:
if not websocket.is_speaking:
# print("vad end point")
if websocket.mode == "2pass" or websocket.mode == "offline":
audio_in = b"".join(frames_asr)
@ -506,7 +504,7 @@ async def async_vad(websocket, audio_in):
return speech_start, speech_end
if len(args.certfile) > 0:
if False: # len(args.certfile) > 0:
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
# Generate with Lets Encrypt, copied to this location, chown to current user and 400 permissions

View File

@ -3,6 +3,7 @@ import asyncio
import json
import websockets
import time
from datetime import datetime
import argparse
import ssl
import numpy as np
@ -14,7 +15,7 @@ from modelscope.hub.snapshot_download import snapshot_download
parser = argparse.ArgumentParser()
parser.add_argument(
"--host", type=str, default="127.0.0.1", required=False, help="host ip, localhost, 0.0.0.0"
"--host", type=str, default="0.0.0.0", required=False, help="host ip, localhost, 0.0.0.0"
)
parser.add_argument("--port", type=int, default=10095, required=False, help="grpc server port")
parser.add_argument(
@ -30,14 +31,14 @@ parser.add_argument("--ncpu", type=int, default=4, help="cpu cores")
parser.add_argument(
"--certfile",
type=str,
default="../../ssl_key/server.crt",
default="",
required=False,
help="certfile for ssl",
)
parser.add_argument(
"--keyfile",
type=str,
default="../../ssl_key/server.key",
default="ssl_key/server.key",
required=False,
help="keyfile for ssl",
)
@ -61,20 +62,25 @@ model_vad = AutoModel(
)
api = HubApi()
key = "ed70b703-9ec7-44b8-b5ce-5f4527719810"
api.login(key)
if "key" in os.environ:
key = os.environ["key"]
api.login(key)
api.login(key)
# os.environ["MODELSCOPE_CACHE"] = "/nfs/zhifu.gzf/modelscope"
# llm_dir = snapshot_download('qwen/Qwen2-7B-Instruct', cache_dir=None, revision='master')
# audio_encoder_dir = snapshot_download('iic/SenseVoice', cache_dir=None, revision='master')
llm_dir = snapshot_download("qwen/Qwen2-7B-Instruct", cache_dir=None, revision="master")
audio_encoder_dir = snapshot_download("iic/SenseVoice", cache_dir=None, revision="master")
llm_dir = "/cpfs_speech/zhifu.gzf/init_model/qwen/Qwen2-7B-Instruct"
audio_encoder_dir = "/nfs/yangyexin.yyx/init_model/iic/SenseVoiceModelscope_0712"
# llm_dir = "/cpfs_speech/zhifu.gzf/init_model/qwen/Qwen2-7B-Instruct"
# audio_encoder_dir = "/nfs/yangyexin.yyx/init_model/iic/SenseVoiceModelscope_0712"
device = "cuda:0"
all_file_paths = [
"/nfs/yangyexin.yyx/init_model/qwen2_7b_mmt_v14_20240830/",
"/nfs/yangyexin.yyx/init_model/audiolm_v14_20240824_train_encoder_all_20240822_lr1e-4_warmup2350/"
"FunAudioLLM/qwen2_7b_mmt_v14_20240830",
"FunAudioLLM/audiolm_v11_20240807",
"FunAudioLLM/Speech2Text_Align_V0712",
"FunAudioLLM/Speech2Text_Align_V0718",
"FunAudioLLM/Speech2Text_Align_V0628",
]
llm_kwargs = {"num_beams": 1, "do_sample": False}
@ -85,6 +91,7 @@ MAX_ITER_PER_CHUNK = 20
ckpt_dir = all_file_paths[0]
def contains_lora_folder(directory):
for name in os.listdir(directory):
full_path = os.path.join(directory, name)
@ -92,7 +99,10 @@ def contains_lora_folder(directory):
return full_path
return None
ckpt_dir = snapshot_download(ckpt_dir, cache_dir=None, revision="master")
lora_folder = contains_lora_folder(ckpt_dir)
if lora_folder is not None:
model_llm = AutoModel(
model=ckpt_dir,
@ -102,7 +112,11 @@ if lora_folder is not None:
llm_dtype="bf16",
max_length=1024,
llm_kwargs=llm_kwargs,
llm_conf={"init_param_path": llm_dir, "lora_conf": {"init_param_path": lora_folder}, "load_kwargs": {"attn_implementation": "eager"}},
llm_conf={
"init_param_path": llm_dir,
"lora_conf": {"init_param_path": lora_folder},
"load_kwargs": {"attn_implementation": "eager"},
},
tokenizer_conf={"init_param_path": llm_dir},
audio_encoder=audio_encoder_dir,
)
@ -128,6 +142,7 @@ model_dict = {"model": model, "frontend": frontend, "tokenizer": tokenizer}
print("model loaded! only support one client at the same time now!!!!")
def load_bytes(input):
middle_data = np.frombuffer(input, dtype=np.int16)
middle_data = np.asarray(middle_data)
@ -143,7 +158,12 @@ def load_bytes(input):
array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
return array
async def streaming_transcribe(websocket, audio_in, his_state=None, asr_prompt=None, s2tt_prompt=None):
async def streaming_transcribe(
websocket, audio_in, his_state=None, asr_prompt=None, s2tt_prompt=None
):
current_time = datetime.now()
print("DEBUG:" + str(current_time) + " call streaming_transcribe function:")
if his_state is None:
his_state = model_dict
model = his_state["model"]
@ -157,17 +177,21 @@ async def streaming_transcribe(websocket, audio_in, his_state=None, asr_prompt=N
else:
previous_asr_text = websocket.streaming_state.get("previous_asr_text", "")
previous_s2tt_text = websocket.streaming_state.get("previous_s2tt_text", "")
previous_vad_onscreen_asr_text = websocket.streaming_state.get("previous_vad_onscreen_asr_text", "")
previous_vad_onscreen_s2tt_text = websocket.streaming_state.get("previous_vad_onscreen_s2tt_text", "")
previous_vad_onscreen_asr_text = websocket.streaming_state.get(
"previous_vad_onscreen_asr_text", ""
)
previous_vad_onscreen_s2tt_text = websocket.streaming_state.get(
"previous_vad_onscreen_s2tt_text", ""
)
if asr_prompt is None or asr_prompt == "":
asr_prompt = "Copy:"
asr_prompt = "Speech transcription:"
if s2tt_prompt is None or s2tt_prompt == "":
s2tt_prompt = "Translate the following sentence into English:"
s2tt_prompt = "Translate into English:"
audio_seconds = load_bytes(audio_in).shape[0] / 16000
print(f"Streaming audio length: {audio_seconds} seconds")
asr_content = []
system_prompt = "You are a helpful assistant."
asr_content.append({"role": "system", "content": system_prompt})
@ -185,12 +209,24 @@ async def streaming_transcribe(websocket, audio_in, his_state=None, asr_prompt=N
streaming_time_beg = time.time()
inputs_asr_embeds, contents, batch, source_ids, meta_data = model.inference_prepare(
[asr_content], None, "test_demo", tokenizer, frontend, device=device, infer_with_assistant_input=True
[asr_content],
None,
"test_demo",
tokenizer,
frontend,
device=device,
infer_with_assistant_input=True,
)
model_asr_inputs = {}
model_asr_inputs["inputs_embeds"] = inputs_asr_embeds
inputs_s2tt_embeds, contents, batch, source_ids, meta_data = model.inference_prepare(
[s2tt_content], None, "test_demo", tokenizer, frontend, device=device, infer_with_assistant_input=True
[s2tt_content],
None,
"test_demo",
tokenizer,
frontend,
device=device,
infer_with_assistant_input=True,
)
model_s2tt_inputs = {}
model_s2tt_inputs["inputs_embeds"] = inputs_s2tt_embeds
@ -208,7 +244,7 @@ async def streaming_transcribe(websocket, audio_in, his_state=None, asr_prompt=N
s2tt_generation_kwargs.update(llm_kwargs)
s2tt_thread = Thread(target=model.llm.generate, kwargs=s2tt_generation_kwargs)
s2tt_thread.start()
onscreen_asr_res = previous_asr_text
onscreen_s2tt_res = previous_s2tt_text
@ -220,8 +256,8 @@ async def streaming_transcribe(websocket, audio_in, his_state=None, asr_prompt=N
is_s2tt_repetition = False
for new_asr_text in asr_streamer:
print(f"generated new asr text {new_asr_text}")
current_time = datetime.now()
print("DEBUG: " + str(current_time) + " " + f"generated new asr text {new_asr_text}")
if len(new_asr_text) > 0:
onscreen_asr_res += new_asr_text.replace("<|im_end|>", "")
if len(new_asr_text.replace("<|im_end|>", "")) > 0:
@ -233,7 +269,13 @@ async def streaming_transcribe(websocket, audio_in, his_state=None, asr_prompt=N
if remain_s2tt_text:
try:
new_s2tt_text = next(s2tt_streamer)
print(f"generated new s2tt text {new_s2tt_text}")
current_time = datetime.now()
print(
"DEBUG: "
+ str(current_time)
+ " "
+ f"generated new s2tt text {new_s2tt_text}"
)
s2tt_iter_cnt += 1
if len(new_s2tt_text) > 0:
onscreen_s2tt_res += new_s2tt_text.replace("<|im_end|>", "")
@ -241,16 +283,16 @@ async def streaming_transcribe(websocket, audio_in, his_state=None, asr_prompt=N
new_s2tt_text = ""
remain_s2tt_text = False
pass
if len(new_asr_text) > 0 or len(new_s2tt_text) > 0:
all_asr_res = previous_vad_onscreen_asr_text + onscreen_asr_res
fix_asr_part = previous_vad_onscreen_asr_text + previous_asr_text
unfix_asr_part = all_asr_res[len(fix_asr_part):]
return_asr_res = fix_asr_part + "<em>"+ unfix_asr_part + "</em>"
unfix_asr_part = all_asr_res[len(fix_asr_part) :]
return_asr_res = fix_asr_part + "<em>" + unfix_asr_part + "</em>"
all_s2tt_res = previous_vad_onscreen_s2tt_text + onscreen_s2tt_res
fix_s2tt_part = previous_vad_onscreen_s2tt_text + previous_s2tt_text
unfix_s2tt_part = all_s2tt_res[len(fix_s2tt_part):]
return_s2tt_res = fix_s2tt_part + "<em>"+ unfix_s2tt_part + "</em>"
unfix_s2tt_part = all_s2tt_res[len(fix_s2tt_part) :]
return_s2tt_res = fix_s2tt_part + "<em>" + unfix_s2tt_part + "</em>"
message = json.dumps(
{
"mode": "online",
@ -261,14 +303,19 @@ async def streaming_transcribe(websocket, audio_in, his_state=None, asr_prompt=N
}
)
await websocket.send(message)
websocket.streaming_state["onscreen_asr_res"] = previous_vad_onscreen_asr_text + onscreen_asr_res
websocket.streaming_state["onscreen_s2tt_res"] = previous_vad_onscreen_s2tt_text + onscreen_s2tt_res
websocket.streaming_state["onscreen_asr_res"] = (
previous_vad_onscreen_asr_text + onscreen_asr_res
)
websocket.streaming_state["onscreen_s2tt_res"] = (
previous_vad_onscreen_s2tt_text + onscreen_s2tt_res
)
if remain_s2tt_text:
for new_s2tt_text in s2tt_streamer:
print(f"generated new s2tt text {new_s2tt_text}")
current_time = datetime.now()
print(
"DEBUG: " + str(current_time) + " " + f"generated new s2tt text {new_s2tt_text}"
)
if len(new_s2tt_text) > 0:
onscreen_s2tt_res += new_s2tt_text.replace("<|im_end|>", "")
if len(new_s2tt_text.replace("<|im_end|>", "")) > 0:
@ -276,16 +323,16 @@ async def streaming_transcribe(websocket, audio_in, his_state=None, asr_prompt=N
if s2tt_iter_cnt > MAX_ITER_PER_CHUNK:
is_s2tt_repetition = True
break
if len(new_s2tt_text) > 0:
all_asr_res = previous_vad_onscreen_asr_text + onscreen_asr_res
fix_asr_part = previous_vad_onscreen_asr_text + previous_asr_text
unfix_asr_part = all_asr_res[len(fix_asr_part):]
return_asr_res = fix_asr_part + "<em>"+ unfix_asr_part + "</em>"
unfix_asr_part = all_asr_res[len(fix_asr_part) :]
return_asr_res = fix_asr_part + "<em>" + unfix_asr_part + "</em>"
all_s2tt_res = previous_vad_onscreen_s2tt_text + onscreen_s2tt_res
fix_s2tt_part = previous_vad_onscreen_s2tt_text + previous_s2tt_text
unfix_s2tt_part = all_s2tt_res[len(fix_s2tt_part):]
return_s2tt_res = fix_s2tt_part + "<em>"+ unfix_s2tt_part + "</em>"
unfix_s2tt_part = all_s2tt_res[len(fix_s2tt_part) :]
return_s2tt_res = fix_s2tt_part + "<em>" + unfix_s2tt_part + "</em>"
message = json.dumps(
{
"mode": "online",
@ -296,8 +343,12 @@ async def streaming_transcribe(websocket, audio_in, his_state=None, asr_prompt=N
}
)
await websocket.send(message)
websocket.streaming_state["onscreen_asr_res"] = previous_vad_onscreen_asr_text + onscreen_asr_res
websocket.streaming_state["onscreen_s2tt_res"] = previous_vad_onscreen_s2tt_text + onscreen_s2tt_res
websocket.streaming_state["onscreen_asr_res"] = (
previous_vad_onscreen_asr_text + onscreen_asr_res
)
websocket.streaming_state["onscreen_s2tt_res"] = (
previous_vad_onscreen_s2tt_text + onscreen_s2tt_res
)
streaming_time_end = time.time()
print(f"Streaming inference time: {streaming_time_end - streaming_time_beg}")
@ -307,16 +358,24 @@ async def streaming_transcribe(websocket, audio_in, his_state=None, asr_prompt=N
if asr_text_len > UNFIX_LEN and audio_seconds > MIN_LEN_SEC_AUDIO_FIX and not is_asr_repetition:
pre_previous_asr_text = previous_asr_text
previous_asr_text = tokenizer.decode(tokenizer.encode(onscreen_asr_res)[:-UNFIX_LEN]).replace("<EFBFBD>", "")
previous_asr_text = tokenizer.decode(
tokenizer.encode(onscreen_asr_res)[:-UNFIX_LEN]
).replace("<EFBFBD>", "")
if len(previous_asr_text) <= len(pre_previous_asr_text):
previous_asr_text = pre_previous_asr_text
elif is_asr_repetition:
pass
else:
previous_asr_text = ""
if s2tt_text_len > UNFIX_LEN and audio_seconds > MIN_LEN_SEC_AUDIO_FIX and not is_s2tt_repetition:
if (
s2tt_text_len > UNFIX_LEN
and audio_seconds > MIN_LEN_SEC_AUDIO_FIX
and not is_s2tt_repetition
):
pre_previous_s2tt_text = previous_s2tt_text
previous_s2tt_text = tokenizer.decode(tokenizer.encode(onscreen_s2tt_res)[:-UNFIX_LEN]).replace("<EFBFBD>", "")
previous_s2tt_text = tokenizer.decode(
tokenizer.encode(onscreen_s2tt_res)[:-UNFIX_LEN]
).replace("<EFBFBD>", "")
if len(previous_s2tt_text) <= len(pre_previous_s2tt_text):
previous_s2tt_text = pre_previous_s2tt_text
elif is_s2tt_repetition:
@ -325,9 +384,13 @@ async def streaming_transcribe(websocket, audio_in, his_state=None, asr_prompt=N
previous_s2tt_text = ""
websocket.streaming_state["previous_asr_text"] = previous_asr_text
websocket.streaming_state["onscreen_asr_res"] = previous_vad_onscreen_asr_text + onscreen_asr_res
websocket.streaming_state["onscreen_asr_res"] = (
previous_vad_onscreen_asr_text + onscreen_asr_res
)
websocket.streaming_state["previous_s2tt_text"] = previous_s2tt_text
websocket.streaming_state["onscreen_s2tt_res"] = previous_vad_onscreen_s2tt_text + onscreen_s2tt_res
websocket.streaming_state["onscreen_s2tt_res"] = (
previous_vad_onscreen_s2tt_text + onscreen_s2tt_res
)
print("fix asr part:", previous_asr_text)
print("fix s2tt part:", previous_s2tt_text)
@ -383,6 +446,13 @@ async def ws_serve(websocket, path):
try:
async for message in websocket:
if isinstance(message, str):
current_time = datetime.now()
print("DEBUG:" + str(current_time) + " received message:", message)
else:
current_time = datetime.now()
print("DEBUG:" + str(current_time) + " received audio bytes:")
if isinstance(message, str):
messagejson = json.loads(message)
@ -403,11 +473,11 @@ async def ws_serve(websocket, path):
if "asr_prompt" in messagejson:
asr_prompt = messagejson["asr_prompt"]
else:
asr_prompt = "Copy:"
asr_prompt = "Speech transcription:"
if "s2tt_prompt" in messagejson:
s2tt_prompt = messagejson["s2tt_prompt"]
else:
s2tt_prompt = "Translate the following sentence into English:"
s2tt_prompt = "Translate into English:"
websocket.status_dict_vad["chunk_size"] = int(
chunk_size[1] * 60 / websocket.chunk_interval
@ -421,19 +491,20 @@ async def ws_serve(websocket, path):
# asr online
websocket.streaming_state["is_final"] = speech_end_i != -1
if (
(len(frames_asr) % websocket.chunk_interval == 0
or websocket.streaming_state["is_final"])
and len(frames_asr) != 0
):
len(frames_asr) % websocket.chunk_interval == 0
or websocket.streaming_state["is_final"]
) and len(frames_asr) != 0:
audio_in = b"".join(frames_asr)
try:
await streaming_transcribe(websocket, audio_in, asr_prompt=asr_prompt, s2tt_prompt=s2tt_prompt)
await streaming_transcribe(
websocket, audio_in, asr_prompt=asr_prompt, s2tt_prompt=s2tt_prompt
)
except Exception as e:
print(f"error in streaming, {e}")
print(f"error in streaming, {websocket.streaming_state}")
if speech_start:
frames_asr.append(message)
# vad online
try:
speech_start_i, speech_end_i = await async_vad(websocket, message)
@ -450,7 +521,9 @@ async def ws_serve(websocket, path):
if speech_end_i != -1 or not websocket.is_speaking:
audio_in = b"".join(frames_asr)
try:
await streaming_transcribe(websocket, audio_in, asr_prompt=asr_prompt, s2tt_prompt=s2tt_prompt)
await streaming_transcribe(
websocket, audio_in, asr_prompt=asr_prompt, s2tt_prompt=s2tt_prompt
)
except Exception as e:
print(f"error in streaming, {e}")
print(f"error in streaming, {websocket.streaming_state}")
@ -460,16 +533,37 @@ async def ws_serve(websocket, path):
websocket.streaming_state["previous_s2tt_text"] = ""
now_onscreen_asr_res = websocket.streaming_state.get("onscreen_asr_res", "")
now_onscreen_s2tt_res = websocket.streaming_state.get("onscreen_s2tt_res", "")
if len(tokenizer.encode(now_onscreen_asr_res.split("\n\n")[-1])) < MIN_LEN_PER_PARAGRAPH or len(tokenizer.encode(now_onscreen_s2tt_res.split("\n\n")[-1])) < MIN_LEN_PER_PARAGRAPH:
if now_onscreen_asr_res.endswith(".") or now_onscreen_asr_res.endswith("?") or now_onscreen_asr_res.endswith("!"):
if (
len(tokenizer.encode(now_onscreen_asr_res.split("\n\n")[-1]))
< MIN_LEN_PER_PARAGRAPH
or len(tokenizer.encode(now_onscreen_s2tt_res.split("\n\n")[-1]))
< MIN_LEN_PER_PARAGRAPH
):
if (
now_onscreen_asr_res.endswith(".")
or now_onscreen_asr_res.endswith("?")
or now_onscreen_asr_res.endswith("!")
):
now_onscreen_asr_res += " "
if now_onscreen_s2tt_res.endswith(".") or now_onscreen_s2tt_res.endswith("?") or now_onscreen_s2tt_res.endswith("!"):
if (
now_onscreen_s2tt_res.endswith(".")
or now_onscreen_s2tt_res.endswith("?")
or now_onscreen_s2tt_res.endswith("!")
):
now_onscreen_s2tt_res += " "
websocket.streaming_state["previous_vad_onscreen_asr_text"] = now_onscreen_asr_res
websocket.streaming_state["previous_vad_onscreen_s2tt_text"] = now_onscreen_s2tt_res
websocket.streaming_state["previous_vad_onscreen_asr_text"] = (
now_onscreen_asr_res
)
websocket.streaming_state["previous_vad_onscreen_s2tt_text"] = (
now_onscreen_s2tt_res
)
else:
websocket.streaming_state["previous_vad_onscreen_asr_text"] = now_onscreen_asr_res + "\n\n"
websocket.streaming_state["previous_vad_onscreen_s2tt_text"] = now_onscreen_s2tt_res + "\n\n"
websocket.streaming_state["previous_vad_onscreen_asr_text"] = (
now_onscreen_asr_res + "\n\n"
)
websocket.streaming_state["previous_vad_onscreen_s2tt_text"] = (
now_onscreen_s2tt_res + "\n\n"
)
if not websocket.is_speaking:
websocket.vad_pre_idx = 0
frames = []
@ -491,6 +585,8 @@ async def ws_serve(websocket, path):
async def async_vad(websocket, audio_in):
current_time = datetime.now()
print("DEBUG:" + str(current_time) + " call vad function:")
segments_result = model_vad.generate(input=audio_in, **websocket.status_dict_vad)[0]["value"]
# print(segments_result)
@ -506,7 +602,7 @@ async def async_vad(websocket, audio_in):
return speech_start, speech_end
if len(args.certfile) > 0:
if False:
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
# Generate with Lets Encrypt, copied to this location, chown to current user and 400 permissions