streaming: interleave ASR and S2TT streaming in the websocket demo

yangyexin.yyx 2024-08-28 14:20:14 +08:00
parent 1eb7507c24
commit 71b6ecbb39
3 changed files with 171 additions and 242 deletions

View File

@ -930,6 +930,7 @@ class LLMASR4(nn.Module):
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
init_param_path = llm_conf.get("init_param_path", "vicuna-7b-v1.5")
llm_load_kwargs = llm_conf.get("load_kwargs", {})
if not llm_conf.get("low_cpu", False):
model = AutoModelForCausalLM.from_pretrained(
@ -937,6 +938,7 @@ class LLMASR4(nn.Module):
load_in_8bit=None,
device_map=None,
use_cache=None,
**llm_load_kwargs,
)
else:
import os
@ -947,6 +949,7 @@ class LLMASR4(nn.Module):
load_in_8bit=None,
device_map="cpu",
use_cache=None,
**llm_load_kwargs,
)
else:
llm_config = AutoConfig.from_pretrained(init_param_path)
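The new `load_kwargs` hook forwards arbitrary keyword arguments from the model config into Hugging Face's `from_pretrained`. A minimal sketch of the pattern in isolation, assuming a recent `transformers` release that accepts `attn_implementation` (the config values mirror this commit's defaults):

```python
# Sketch: forwarding a config-provided kwargs dict into from_pretrained.
from transformers import AutoModelForCausalLM

llm_conf = {
    "init_param_path": "vicuna-7b-v1.5",              # default used above
    "load_kwargs": {"attn_implementation": "eager"},  # e.g. opt out of SDPA
}

model = AutoModelForCausalLM.from_pretrained(
    llm_conf.get("init_param_path"),
    **llm_conf.get("load_kwargs", {}),  # empty dict preserves the old behavior
)
```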

View File

@ -22,17 +22,12 @@ parser.add_argument(
)
parser.add_argument("--port", type=int, default=10095, required=False, help="grpc server port")
parser.add_argument("--chunk_size", type=str, default="5, 10, 5", help="chunk")
parser.add_argument("--encoder_chunk_look_back", type=int, default=4, help="chunk")
parser.add_argument("--decoder_chunk_look_back", type=int, default=0, help="chunk")
parser.add_argument("--chunk_interval", type=int, default=10, help="chunk")
parser.add_argument(
"--hotword",
type=str,
default="",
help="hotword file path, one hotword perline (e.g.:阿里巴巴 20)",
)
parser.add_argument("--chunk_interval", type=int, default=60, help="chunk")
parser.add_argument("--audio_in", type=str, default=None, help="audio_in")
parser.add_argument("--audio_fs", type=int, default=16000, help="audio_fs")
parser.add_argument("--asr_prompt", type=str, default="Copy:", help="asr prompt")
parser.add_argument("--s2tt_prompt", type=str, default="Translate the following sentence into English:", help="s2tt prompt")
parser.add_argument(
"--send_without_sleep",
action="store_true",
@ -41,10 +36,9 @@ parser.add_argument(
)
parser.add_argument("--thread_num", type=int, default=1, help="thread_num")
parser.add_argument("--words_max_print", type=int, default=10000, help="chunk")
parser.add_argument("--output_dir", type=str, default=None, help="output_dir")
parser.add_argument("--ssl", type=int, default=1, help="1 for ssl connect, 0 for no ssl")
parser.add_argument("--use_itn", type=int, default=1, help="1 for using itn, 0 for not itn")
parser.add_argument("--mode", type=str, default="2pass", help="offline, online, 2pass")
parser.add_argument("--mode", type=str, default="online", help="offline, online, 2pass")
args = parser.parse_args()
args.chunk_size = [int(x) for x in args.chunk_size.split(",")]
@ -53,14 +47,6 @@ print(args)
from queue import Queue
voices = Queue()
offline_msg_done = False
if args.output_dir is not None:
# if os.path.exists(args.output_dir):
# os.remove(args.output_dir)
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
async def record_microphone():
@ -80,41 +66,16 @@ async def record_microphone():
stream = p.open(
format=FORMAT, channels=CHANNELS, rate=RATE, input=True, frames_per_buffer=CHUNK
)
# hotwords
fst_dict = {}
hotword_msg = ""
if args.hotword.strip() != "":
if os.path.exists(args.hotword):
f_scp = open(args.hotword)
hot_lines = f_scp.readlines()
for line in hot_lines:
words = line.strip().split(" ")
if len(words) < 2:
print("Please checkout format of hotwords")
continue
try:
fst_dict[" ".join(words[:-1])] = int(words[-1])
except ValueError:
print("Please checkout format of hotwords")
hotword_msg = json.dumps(fst_dict)
else:
hotword_msg = args.hotword
use_itn = True
if args.use_itn == 0:
use_itn = False
message = json.dumps(
{
"mode": args.mode,
"chunk_size": args.chunk_size,
"chunk_interval": args.chunk_interval,
"encoder_chunk_look_back": args.encoder_chunk_look_back,
"decoder_chunk_look_back": args.decoder_chunk_look_back,
"wav_name": "microphone",
"is_speaking": True,
"hotwords": hotword_msg,
"itn": use_itn,
"asr_prompt": args.asr_prompt,
"s2tt_prompt": args.s2tt_prompt,
}
)
# voices.put(message)
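Putting the pieces together, the handshake the microphone path sends is a single JSON text frame followed by binary PCM frames. A stripped-down sketch, with host, port, and prompts as placeholders taken from this script's defaults:

```python
# Minimal client handshake sketch; URI and values are illustrative defaults.
import asyncio
import json
import websockets  # same dependency the full client uses

async def handshake_demo():
    async with websockets.connect("ws://127.0.0.1:10095") as ws:
        await ws.send(json.dumps({
            "mode": "online",
            "chunk_size": [5, 10, 5],
            "chunk_interval": 60,
            "wav_name": "microphone",
            "is_speaking": True,
            "asr_prompt": "Copy:",
            "s2tt_prompt": "Translate the following sentence into English:",
        }))
        # ...then stream raw PCM audio as binary frames...

asyncio.run(handshake_demo())
```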
@ -136,32 +97,8 @@ async def record_from_scp(chunk_begin, chunk_size):
else:
wavs = [args.audio_in]
# hotwords
fst_dict = {}
hotword_msg = ""
if args.hotword.strip() != "":
if os.path.exists(args.hotword):
f_scp = open(args.hotword)
hot_lines = f_scp.readlines()
for line in hot_lines:
words = line.strip().split(" ")
if len(words) < 2:
print("Please checkout format of hotwords")
continue
try:
fst_dict[" ".join(words[:-1])] = int(words[-1])
except ValueError:
print("Please checkout format of hotwords")
hotword_msg = json.dumps(fst_dict)
else:
hotword_msg = args.hotword
print(hotword_msg)
sample_rate = args.audio_fs
wav_format = "pcm"
use_itn = True
if args.use_itn == 0:
use_itn = False
if chunk_size > 0:
wavs = wavs[chunk_begin : chunk_begin + chunk_size]
@ -198,14 +135,10 @@ async def record_from_scp(chunk_begin, chunk_size):
"mode": args.mode,
"chunk_size": args.chunk_size,
"chunk_interval": args.chunk_interval,
"encoder_chunk_look_back": args.encoder_chunk_look_back,
"decoder_chunk_look_back": args.decoder_chunk_look_back,
"audio_fs": sample_rate,
"wav_name": wav_name,
"wav_format": wav_format,
"wav_name": "microphone",
"is_speaking": True,
"hotwords": hotword_msg,
"itn": use_itn,
"asr_prompt": args.asr_prompt,
"s2tt_prompt": args.s2tt_prompt,
}
)
@ -230,76 +163,24 @@ async def record_from_scp(chunk_begin, chunk_size):
await asyncio.sleep(sleep_duration)
if not args.mode == "offline":
await asyncio.sleep(2)
# offline model need to wait for message recved
if args.mode == "offline":
global offline_msg_done
while not offline_msg_done:
await asyncio.sleep(1)
await asyncio.sleep(2)
await websocket.close()
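The `sleep_duration` pacing above throttles file playback to real time: each send should cover as much wall-clock time as the audio it carries. A sketch of the arithmetic, assuming the stride formula this family of clients typically uses (the formula itself is not shown in the hunk):

```python
# Assumed pacing: sleep for the duration of audio carried by each send.
chunk_size = [5, 10, 5]                          # 60 ms units, as above
chunk_interval = 60                              # new default in this commit
stride_ms = chunk_size[1] * 60 / chunk_interval  # 10 * 60 / 60 = 10 ms of audio
sleep_duration = stride_ms / 1000.0              # seconds per send
```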
async def message(id):
global websocket, voices, offline_msg_done
global websocket, voices
text_print = ""
text_print_2pass_online = ""
text_print_2pass_offline = ""
if args.output_dir is not None:
ibest_writer = open(
os.path.join(args.output_dir, "text.{}".format(id)), "a", encoding="utf-8"
)
else:
ibest_writer = None
try:
while True:
meg = await websocket.recv()
meg = json.loads(meg)
wav_name = meg.get("wav_name", "demo")
text = meg["text"]
timestamp = ""
offline_msg_done = meg.get("is_final", False)
if "timestamp" in meg:
timestamp = meg["timestamp"]
asr_text = meg["asr_text"]
s2tt_text = meg["s2tt_text"]
if ibest_writer is not None:
if timestamp != "":
text_write_line = "{}\t{}\t{}\n".format(wav_name, text, timestamp)
else:
text_write_line = "{}\t{}\n".format(wav_name, text)
ibest_writer.write(text_write_line)
if "mode" not in meg:
continue
if meg["mode"] == "online":
text_print = text
os.system("clear")
print("\rpid" + str(id) + ": " + text_print)
elif meg["mode"] == "offline":
if timestamp != "":
text_print += "{} timestamp: {}".format(text, timestamp)
else:
text_print += "{}".format(text)
# text_print = text_print[-args.words_max_print:]
# os.system('clear')
print("\rpid" + str(id) + ": " + wav_name + ": " + text_print)
offline_msg_done = True
else:
if meg["mode"] == "2pass-online":
text_print_2pass_online += "{}".format(text)
text_print = text_print_2pass_offline + text_print_2pass_online
else:
text_print_2pass_online = ""
text_print = text_print_2pass_offline + "{}".format(text)
text_print_2pass_offline += "{}".format(text)
text_print = text_print[-args.words_max_print :]
os.system("clear")
print("\rpid" + str(id) + ": " + text_print)
# offline_msg_done=True
text_print = "\n\n" + "ASR: " + asr_text + "\n\n" + "S2TT: " + s2tt_text
os.system("clear")
print("\rpid" + str(id) + ": " + text_print)
except Exception as e:
print("Exception:", e)
@ -311,10 +192,9 @@ async def ws_client(id, chunk_begin, chunk_size):
if args.audio_in is None:
chunk_begin = 0
chunk_size = 1
global websocket, voices, offline_msg_done
global websocket, voices
for i in range(chunk_begin, chunk_begin + chunk_size):
offline_msg_done = False
voices = Queue()
if args.ssl == 1:
ssl_context = ssl.SSLContext()

View File

@ -1,38 +1,22 @@
import os
import asyncio
import json
import websockets
import time
import logging
import tracemalloc
import numpy as np
import argparse
import ssl
import os
import torch
import torchaudio
from transformers import TextIteratorStreamer
import numpy as np
from threading import Thread
import traceback
from transformers import TextIteratorStreamer
from funasr import AutoModel
from modelscope.hub.api import HubApi
from modelscope.hub.snapshot_download import snapshot_download
parser = argparse.ArgumentParser()
parser.add_argument(
"--host", type=str, default="127.0.0.1", required=False, help="host ip, localhost, 0.0.0.0"
)
parser.add_argument("--port", type=int, default=10095, required=False, help="grpc server port")
parser.add_argument(
"--asr_model",
type=str,
default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
help="model from modelscope",
)
parser.add_argument("--asr_model_revision", type=str, default="master", help="")
parser.add_argument(
"--asr_model_online",
type=str,
default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online",
help="model from modelscope",
)
parser.add_argument("--asr_model_online_revision", type=str, default="master", help="")
parser.add_argument(
"--vad_model",
type=str,
@ -50,7 +34,6 @@ parser.add_argument(
required=False,
help="certfile for ssl",
)
parser.add_argument(
"--keyfile",
type=str,
@ -63,9 +46,6 @@ args = parser.parse_args()
websocket_users = set()
print("model loading")
from funasr import AutoModel
# vad
model_vad = AutoModel(
model=args.vad_model,
@ -78,17 +58,11 @@ model_vad = AutoModel(
# chunk_size=60,
)
from funasr import AutoModel
from modelscope.hub.api import HubApi
api = HubApi()
if "key" in os.environ:
key = os.environ["key"]
api.login(key)
from modelscope.hub.snapshot_download import snapshot_download
# os.environ["MODELSCOPE_CACHE"] = "/nfs/zhifu.gzf/modelscope"
# llm_dir = snapshot_download('qwen/Qwen2-7B-Instruct', cache_dir=None, revision='master')
# audio_encoder_dir = snapshot_download('iic/SenseVoice', cache_dir=None, revision='master')
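For reference, resolving those commented-out model directories looks like this; the model ids and revision are copied verbatim from the comments above, and the cache directory is left at its default:

```python
# Sketch: fetching the LLM and audio-encoder checkpoints named above.
from modelscope.hub.snapshot_download import snapshot_download

llm_dir = snapshot_download("qwen/Qwen2-7B-Instruct", cache_dir=None, revision="master")
audio_encoder_dir = snapshot_download("iic/SenseVoice", cache_dir=None, revision="master")
```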
@ -102,7 +76,6 @@ all_file_paths = [
llm_kwargs = {"num_beams": 1, "do_sample": False}
unfix_len = 5
max_streaming_res_onetime = 100
ckpt_dir = all_file_paths[0]
@ -114,7 +87,7 @@ model_llm = AutoModel(
llm_dtype="bf16",
max_length=1024,
llm_kwargs=llm_kwargs,
llm_conf={"init_param_path": llm_dir},
llm_conf={"init_param_path": llm_dir, "load_kwargs": {"attn_implementation": "eager"}},
tokenizer_conf={"init_param_path": llm_dir},
audio_encoder=audio_encoder_dir,
)
@ -125,6 +98,8 @@ tokenizer = model_llm.kwargs["tokenizer"]
model_dict = {"model": model, "frontend": frontend, "tokenizer": tokenizer}
print("model loaded! only support one client at the same time now!!!!")
def load_bytes(input):
middle_data = np.frombuffer(input, dtype=np.int16)
middle_data = np.asarray(middle_data)
@ -140,7 +115,7 @@ def load_bytes(input):
array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
return array
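`load_bytes` turns little-endian int16 PCM into float32 in [-1, 1). The same normalization in its most direct form, as a point of comparison:

```python
# Sketch: int16 PCM bytes -> float32 waveform normalized to [-1.0, 1.0).
import numpy as np

def pcm16_to_float32(raw: bytes) -> np.ndarray:
    samples = np.frombuffer(raw, dtype=np.int16)
    return samples.astype(np.float32) / 32768.0
```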
async def streaming_transcribe(websocket, audio_in, his_state=None, prompt=None):
async def streaming_transcribe(websocket, audio_in, his_state=None, asr_prompt=None, s2tt_prompt=None):
if his_state is None:
his_state = model_dict
model = his_state["model"]
@ -148,11 +123,19 @@ async def streaming_transcribe(websocket, audio_in, his_state=None, prompt=None)
if websocket.streaming_state is None:
previous_asr_text = ""
previous_s2tt_text = ""
previous_vad_onscreen_asr_text = ""
previous_vad_onscreen_s2tt_text = ""
else:
previous_asr_text = websocket.streaming_state["previous_asr_text"]
previous_asr_text = websocket.streaming_state.get("previous_asr_text", "")
previous_s2tt_text = websocket.streaming_state.get("previous_s2tt_text", "")
previous_vad_onscreen_asr_text = websocket.streaming_state.get("previous_vad_onscreen_asr_text", "")
previous_vad_onscreen_s2tt_text = websocket.streaming_state.get("previous_vad_onscreen_s2tt_text", "")
if prompt is None:
prompt = "Copy:"
if asr_prompt is None or asr_prompt == "":
asr_prompt = "Copy:"
if s2tt_prompt is None or s2tt_prompt == "":
s2tt_prompt = "Translate the following sentence into English:"
audio_seconds = load_bytes(audio_in).shape[0] / 16000
print(f"Streaming audio length: {audio_seconds} seconds")
@ -160,74 +143,128 @@ async def streaming_transcribe(websocket, audio_in, his_state=None, prompt=None)
asr_content = []
system_prompt = "You are a helpful assistant."
asr_content.append({"role": "system", "content": system_prompt})
s2tt_content = []
system_prompt = "You are a helpful assistant."
s2tt_content.append({"role": "system", "content": system_prompt})
asr_user_prompt = f"{prompt}<|startofspeech|>!!<|endofspeech|><|im_end|>\n<|im_start|>assistant\n{previous_asr_text}"
user_asr_prompt = f"{asr_prompt}<|startofspeech|>!!<|endofspeech|><|im_end|>\n<|im_start|>assistant\n{previous_asr_text}"
user_s2tt_prompt = f"{s2tt_prompt}<|startofspeech|>!!<|endofspeech|><|im_end|>\n<|im_start|>assistant\n{previous_s2tt_text}"
asr_content.append({"role": "user", "content": asr_user_prompt, "audio": audio_in})
asr_content.append({"role": "user", "content": user_asr_prompt, "audio": audio_in})
asr_content.append({"role": "assistant", "content": "target_out"})
s2tt_content.append({"role": "user", "content": user_s2tt_prompt, "audio": audio_in})
s2tt_content.append({"role": "assistant", "content": "target_out"})
streaming_asr_time_beg = time.time()
asr_inputs_embeds, contents, batch, source_ids, meta_data = model.inference_prepare(
streaming_time_beg = time.time()
inputs_asr_embeds, contents, batch, source_ids, meta_data = model.inference_prepare(
[asr_content], None, "test_demo", tokenizer, frontend, device=device, infer_with_assistant_input=True
)
asr_model_inputs = {}
asr_model_inputs["inputs_embeds"] = asr_inputs_embeds
model_asr_inputs = {}
model_asr_inputs["inputs_embeds"] = inputs_asr_embeds
inputs_s2tt_embeds, contents, batch, source_ids, meta_data = model.inference_prepare(
[s2tt_content], None, "test_demo", tokenizer, frontend, device=device, infer_with_assistant_input=True
)
model_s2tt_inputs = {}
model_s2tt_inputs["inputs_embeds"] = inputs_s2tt_embeds
print("previous_asr_text:", previous_asr_text)
print("previous_s2tt_text:", previous_s2tt_text)
streamer = TextIteratorStreamer(tokenizer)
generation_kwargs = dict(asr_model_inputs, streamer=streamer, max_new_tokens=1024)
thread = Thread(target=model.llm.generate, kwargs=generation_kwargs)
thread.start()
asr_streamer = TextIteratorStreamer(tokenizer)
asr_generation_kwargs = dict(model_asr_inputs, streamer=asr_streamer, max_new_tokens=1024)
asr_generation_kwargs.update(llm_kwargs)
asr_thread = Thread(target=model.llm.generate, kwargs=asr_generation_kwargs)
asr_thread.start()
s2tt_streamer = TextIteratorStreamer(tokenizer)
s2tt_generation_kwargs = dict(model_s2tt_inputs, streamer=s2tt_streamer, max_new_tokens=1024)
s2tt_generation_kwargs.update(llm_kwargs)
s2tt_thread = Thread(target=model.llm.generate, kwargs=s2tt_generation_kwargs)
s2tt_thread.start()
onscreen_asr_res = previous_asr_text
beg_llm = time.time()
for new_text in streamer:
end_llm = time.time()
print(
f"generated new text {new_text}, time_llm_decode: {end_llm - beg_llm:.2f}"
)
if len(new_text) > 0:
onscreen_asr_res += new_text.replace("<|im_end|>", "")
onscreen_s2tt_res = previous_s2tt_text
mode = "online"
for new_asr_text in asr_streamer:
print(f"generated new asr text {new_asr_text}")
if len(new_asr_text) > 0:
onscreen_asr_res += new_asr_text.replace("<|im_end|>", "")
new_s2tt_text = next(s2tt_streamer)
print(f"generated new s2tt text {new_s2tt_text}")
if len(new_s2tt_text) > 0:
onscreen_s2tt_res += new_s2tt_text.replace("<|im_end|>", "")
if len(new_asr_text) > 0 or len(new_s2tt_text) > 0:
message = json.dumps(
{
"mode": mode,
"text": onscreen_asr_res,
"mode": "online",
"asr_text": previous_vad_onscreen_asr_text + onscreen_asr_res,
"s2tt_text": previous_vad_onscreen_s2tt_text + onscreen_s2tt_res,
"wav_name": websocket.wav_name,
"is_final": websocket.is_speaking,
}
)
await websocket.send(message)
websocket.streaming_state["onscreen_asr_res"] = previous_vad_onscreen_asr_text + onscreen_asr_res
websocket.streaming_state["onscreen_s2tt_res"] = previous_vad_onscreen_s2tt_text + onscreen_s2tt_res
streaming_asr_time_end = time.time()
print(f"Streaming ASR inference time: {streaming_asr_time_end - streaming_asr_time_beg}")
for new_s2tt_text in s2tt_streamer:
print(f"generated new s2tt text {new_s2tt_text}")
if len(new_s2tt_text) > 0:
onscreen_s2tt_res += new_s2tt_text.replace("<|im_end|>", "")
if len(new_s2tt_text) > 0:
message = json.dumps(
{
"mode": "online",
"asr_text": previous_vad_onscreen_asr_text + onscreen_asr_res,
"s2tt_text": previous_vad_onscreen_s2tt_text + onscreen_s2tt_res,
"wav_name": websocket.wav_name,
"is_final": websocket.is_speaking,
}
)
await websocket.send(message)
websocket.streaming_state["onscreen_asr_res"] = previous_vad_onscreen_asr_text + onscreen_asr_res
websocket.streaming_state["onscreen_s2tt_res"] = previous_vad_onscreen_s2tt_text + onscreen_s2tt_res
streaming_time_end = time.time()
print(f"Streaming inference time: {streaming_time_end - streaming_time_beg}")
asr_text_len = len(tokenizer.encode(onscreen_asr_res))
s2tt_text_len = len(tokenizer.encode(onscreen_s2tt_res))
if asr_text_len > unfix_len and audio_seconds > 1.1:
if asr_text_len <= max_streaming_res_onetime:
previous_asr_text = tokenizer.decode(tokenizer.encode(onscreen_asr_res)[:-unfix_len])
else:
onscreen_asr_res = previous_asr_text
previous_asr_text = tokenizer.decode(tokenizer.encode(onscreen_asr_res)[:-unfix_len])
else:
previous_asr_text = ""
if s2tt_text_len > unfix_len and audio_seconds > 1.1:
previous_s2tt_text = tokenizer.decode(tokenizer.encode(onscreen_s2tt_res)[:-unfix_len])
else:
previous_s2tt_text = ""
websocket.streaming_state = {}
websocket.streaming_state["previous_asr_text"] = previous_asr_text
websocket.streaming_state["onscreen_asr_res"] = previous_vad_onscreen_asr_text + onscreen_asr_res
websocket.streaming_state["previous_s2tt_text"] = previous_s2tt_text
websocket.streaming_state["onscreen_s2tt_res"] = previous_vad_onscreen_s2tt_text + onscreen_s2tt_res
print("fix asr part:", previous_asr_text)
print("model loaded! only support one client at the same time now!!!!")
print("fix s2tt part:", previous_s2tt_text)
async def ws_reset(websocket):
print("ws reset now, total num is ", len(websocket_users))
websocket.status_dict_asr_online["cache"] = {}
websocket.status_dict_asr_online["is_final"] = True
websocket.streaming_state = None
websocket.streaming_state = {}
websocket.streaming_state["is_final"] = True
websocket.streaming_state["previous_asr_text"] = ""
websocket.streaming_state["previous_s2tt_text"] = ""
websocket.streaming_state["onscreen_asr_res"] = ""
websocket.streaming_state["onscreen_s2tt_res"] = ""
websocket.streaming_state["previous_vad_onscreen_asr_text"] = ""
websocket.streaming_state["previous_vad_onscreen_s2tt_text"] = ""
websocket.status_dict_vad["cache"] = {}
websocket.status_dict_vad["is_final"] = True
@ -246,8 +283,15 @@ async def ws_serve(websocket, path):
global websocket_users
# await clear_websocket()
websocket_users.add(websocket)
websocket.status_dict_asr = {}
websocket.status_dict_asr_online = {"cache": {}, "is_final": False}
websocket.streaming_state = {
"previous_asr_text": "",
"previous_s2tt_text": "",
"onscreen_asr_res": "",
"onscreen_s2tt_res": "",
"previous_vad_onscreen_asr_text": "",
"previous_vad_onscreen_s2tt_text": "",
"is_final": False,
}
websocket.status_dict_vad = {"cache": {}, "is_final": False}
websocket.chunk_interval = 10
@ -255,8 +299,6 @@ async def ws_serve(websocket, path):
speech_start = False
speech_end_i = -1
websocket.wav_name = "microphone"
websocket.mode = "online"
websocket.streaming_state = None
print("new user connected", flush=True)
try:
@ -266,7 +308,9 @@ async def ws_serve(websocket, path):
if "is_speaking" in messagejson:
websocket.is_speaking = messagejson["is_speaking"]
websocket.status_dict_asr_online["is_final"] = not websocket.is_speaking
websocket.streaming_state["is_final"] = not websocket.is_speaking
if not messagejson["is_speaking"]:
await clear_websocket()
if "chunk_interval" in messagejson:
websocket.chunk_interval = messagejson["chunk_interval"]
if "wav_name" in messagejson:
@ -275,22 +319,18 @@ async def ws_serve(websocket, path):
chunk_size = messagejson["chunk_size"]
if isinstance(chunk_size, str):
chunk_size = chunk_size.split(",")
websocket.status_dict_asr_online["chunk_size"] = [int(x) for x in chunk_size]
if "encoder_chunk_look_back" in messagejson:
websocket.status_dict_asr_online["encoder_chunk_look_back"] = messagejson[
"encoder_chunk_look_back"
]
if "decoder_chunk_look_back" in messagejson:
websocket.status_dict_asr_online["decoder_chunk_look_back"] = messagejson[
"decoder_chunk_look_back"
]
if "hotword" in messagejson:
websocket.status_dict_asr["hotword"] = messagejson["hotwords"]
if "mode" in messagejson:
websocket.mode = messagejson["mode"]
chunk_size = [int(x) for x in chunk_size]
if "asr_prompt" in messagejson:
asr_prompt = messagejson["asr_prompt"]
else:
asr_prompt = "Copy:"
if "s2tt_prompt" in messagejson:
s2tt_prompt = messagejson["s2tt_prompt"]
else:
s2tt_prompt = "Translate the following sentence into English:"
websocket.status_dict_vad["chunk_size"] = int(
websocket.status_dict_asr_online["chunk_size"][1] * 60 / websocket.chunk_interval
chunk_size[1] * 60 / websocket.chunk_interval
)
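For concreteness, the VAD window derives from the middle `chunk_size` entry, where each unit is 60 ms of audio. With the defaults in this file:

```python
# Worked numbers for the VAD chunk computation above.
chunk_size = [5, 10, 5]
chunk_interval = 10                                      # server-side default set above
vad_chunk_ms = int(chunk_size[1] * 60 / chunk_interval)  # -> 60 ms per VAD step
```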
if len(frames_asr) > 0 or not isinstance(message, str):
if not isinstance(message, str):
@ -299,20 +339,21 @@ async def ws_serve(websocket, path):
websocket.vad_pre_idx += duration_ms
# asr online
websocket.status_dict_asr_online["is_final"] = speech_end_i != -1
websocket.streaming_state["is_final"] = speech_end_i != -1
if (
(len(frames_asr) % websocket.chunk_interval == 0
or websocket.status_dict_asr_online["is_final"])
or websocket.streaming_state["is_final"])
and len(frames_asr) != 0
):
if websocket.mode == "2pass" or websocket.mode == "online":
audio_in = b"".join(frames_asr)
try:
await streaming_transcribe(websocket, audio_in)
except:
print(f"error in asr streaming, {websocket.status_dict_asr_online}")
audio_in = b"".join(frames_asr)
try:
await streaming_transcribe(websocket, audio_in, asr_prompt=asr_prompt, s2tt_prompt=s2tt_prompt)
except Exception as e:
print(f"error in streaming, {e}")
print(f"error in streaming, {websocket.streaming_state}")
if speech_start:
frames_asr.append(message)
# vad online
try:
speech_start_i, speech_end_i = await async_vad(websocket, message)
@ -324,17 +365,22 @@ async def ws_serve(websocket, path):
frames_pre = frames[-beg_bias:]
frames_asr = []
frames_asr.extend(frames_pre)
# asr punc offline
# vad end
if speech_end_i != -1 or not websocket.is_speaking:
audio_in = b"".join(frames_asr)
frames_asr = []
speech_start = False
websocket.status_dict_asr_online["cache"] = {}
websocket.streaming_state = None
websocket.streaming_state["previous_asr_text"] = ""
websocket.streaming_state["previous_s2tt_text"] = ""
websocket.streaming_state["previous_vad_onscreen_asr_text"] = websocket.streaming_state.get("onscreen_asr_res", "") + "\n\n"
websocket.streaming_state["previous_vad_onscreen_s2tt_text"] = websocket.streaming_state.get("onscreen_s2tt_res", "") + "\n\n"
if not websocket.is_speaking:
websocket.vad_pre_idx = 0
frames = []
websocket.status_dict_vad["cache"] = {}
websocket.streaming_state = None
websocket.streaming_state["previous_asr_text"] = ""
websocket.streaming_state["previous_s2tt_text"] = ""
else:
frames = frames[-20:]
else: