Merge branch 'dev_gzf_deepspeed' of gitlab.alibaba-inc.com:zhifu.gzf/FunASR into dev_gzf_deepspeed

merge
This commit is contained in:
游雁 2024-09-12 17:50:48 +08:00
commit f4b5af8473
3 changed files with 335 additions and 179 deletions

View File

@ -1339,18 +1339,26 @@ class LLMASR4(nn.Module):
# audio encoder
speech = batch["speech"]
if len(speech) > 0:
speech_lengths = batch["speech_lengths"][:, 0]
# fp16
if kwargs.get("fp16", False):
speech = speech.to(torch.float16)
elif kwargs.get("bf16", False):
speech = speech.to(torch.bfloat16)
# audio encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
if "audio_embedding" in kwargs and "audio_embedding_lens" in kwargs:
encoder_out = kwargs["audio_embedding"]
encoder_out_lens = kwargs["audio_embedding_lens"]
else:
speech_lengths = batch["speech_lengths"][:, 0]
# fp16
if kwargs.get("fp16", False):
speech = speech.to(torch.float16)
elif kwargs.get("bf16", False):
speech = speech.to(torch.bfloat16)
# audio encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
# audio_adaptor
encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens)
# audio_adaptor
encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens)
meta_data["audio_adaptor_out"] = encoder_out
meta_data["audio_adaptor_out_lens"] = encoder_out_lens
input_ids = batch["input_ids"]
source_ids = batch["source_ids"]

View File

@ -12,6 +12,9 @@ from transformers import TextIteratorStreamer
from funasr import AutoModel
from modelscope.hub.api import HubApi
from modelscope.hub.snapshot_download import snapshot_download
import torch
import traceback
import re
parser = argparse.ArgumentParser()
parser.add_argument(
@ -59,7 +62,7 @@ model_vad = AutoModel(
disable_pbar=True,
disable_log=True,
speech_noise_thres=0.4,
max_single_segment_time=30000,
max_single_segment_time=35000,
max_end_silence_time=800,
# chunk_size=60,
)
@ -79,7 +82,8 @@ audio_encoder_dir = snapshot_download("iic/SenseVoice", cache_dir=None, revision
# audio_encoder_dir = "/nfs/yangyexin.yyx/init_model/iic/SenseVoiceModelscope_0712"
device = "cuda:0"
all_file_paths = [
# "/nfs/yangyexin.yyx/init_model/s2tt/qwen2_7b_mmt_v15_20240910_streaming",
# "/nfs/yangyexin.yyx/init_model/s2tt/qwen2_7b_mmt_v15_20240912_streaming",
"FunAudioLLM/qwen2_7b_mmt_v15_20240912_streaming",
"FunAudioLLM/qwen2_7b_mmt_v15_20240910_streaming",
"FunAudioLLM/qwen2_7b_mmt_v15_20240902",
"FunAudioLLM/qwen2_7b_mmt_v14_20240830",
@ -93,6 +97,7 @@ llm_kwargs = {"num_beams": 1, "do_sample": False, "repetition_penalty": 1.3}
UNFIX_LEN = 5
MIN_LEN_PER_PARAGRAPH = 25
MIN_LEN_SEC_AUDIO_FIX = 1.1
DO_ASR_FRAME_INTERVAL = 12
ckpt_dir = all_file_paths[0]
@ -134,6 +139,12 @@ def load_bytes(input):
array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
return array
def is_chinese_ending(s):
    """Return True when the string ends with a CJK Unified Ideograph (U+4E00–U+9FFF)."""
    match = re.search(r'[\u4e00-\u9fff]$', s)
    return bool(match)
def is_alpha_ending(s):
    """Return True when the string ends with an ASCII letter (a-z or A-Z)."""
    match = re.search(r'[a-zA-Z]$', s)
    return bool(match)
async def streaming_transcribe(websocket, audio_in, is_vad_end=False, his_state=None, asr_prompt=None, s2tt_prompt=None):
current_time = datetime.now()
print("DEBUG:" + str(current_time) + " call streaming_transcribe function:")
@ -147,6 +158,11 @@ async def streaming_transcribe(websocket, audio_in, is_vad_end=False, his_state=
previous_s2tt_text = ""
previous_vad_onscreen_asr_text = ""
previous_vad_onscreen_s2tt_text = ""
# concat_asr_text = []
# concat_s2tt_text = []
# concat_audio = []
# concat_audio_embedding = []
# concat_audio_embedding_lens = []
else:
previous_asr_text = websocket.streaming_state.get("previous_asr_text", "")
previous_s2tt_text = websocket.streaming_state.get("previous_s2tt_text", "")
@ -156,13 +172,19 @@ async def streaming_transcribe(websocket, audio_in, is_vad_end=False, his_state=
previous_vad_onscreen_s2tt_text = websocket.streaming_state.get(
"previous_vad_onscreen_s2tt_text", ""
)
# concat_asr_text = websocket.streaming_state.get("concat_asr_text", [])
# concat_s2tt_text = websocket.streaming_state.get("concat_s2tt_text", [])
# concat_audio = websocket.streaming_state.get("concat_audio", [])
# concat_audio_embedding = websocket.streaming_state.get("concat_audio_embedding", [])
# concat_audio_embedding_lens = websocket.streaming_state.get("concat_audio_embedding_lens", [])
if asr_prompt is None or asr_prompt == "":
asr_prompt = "Speech transcription:"
if s2tt_prompt is None or s2tt_prompt == "":
s2tt_prompt = "Translate into English:"
audio_seconds = load_bytes(audio_in).shape[0] / 16000
audio_seconds = len(audio_in) // 32 / 1000
cur_audio = audio_in
print(f"Streaming audio length: {audio_seconds} seconds")
asr_content = []
@ -190,6 +212,37 @@ async def streaming_transcribe(websocket, audio_in, is_vad_end=False, his_state=
device=device,
infer_with_assistant_input=True,
)
cur_audio_embedding, cur_audio_embedding_lens = meta_data["audio_adaptor_out"], meta_data["audio_adaptor_out_lens"]
# if not args.return_sentence and len(concat_audio_embedding) != 0:
if False:
audio_embedding = torch.cat([concat_audio_embedding[-1], cur_audio_embedding], dim=1)
audio_embedding_lens = concat_audio_embedding_lens[-1] + cur_audio_embedding_lens
actual_prev_asr_text = concat_asr_text[-1] + previous_asr_text
actual_prev_s2tt_text = concat_s2tt_text[-1] + previous_s2tt_text
actual_audio = concat_audio[-1] + cur_audio
user_asr_prompt = f"{asr_prompt}<|startofspeech|>!!<|endofspeech|><|im_end|>\n<|im_start|>assistant\n{actual_prev_asr_text}"
user_s2tt_prompt = f"{s2tt_prompt}<|startofspeech|>!!<|endofspeech|><|im_end|>\n<|im_start|>assistant\n{actual_prev_s2tt_text}"
asr_content[1] = {"role": "user", "content": user_asr_prompt, "audio": actual_audio}
s2tt_content[1] = {"role": "user", "content": user_s2tt_prompt, "audio": actual_audio}
inputs_asr_embeds, contents, batch, source_ids, meta_data = model.inference_prepare(
[asr_content],
None,
"test_demo",
tokenizer,
frontend,
device=device,
infer_with_assistant_input=True,
audio_embedding=audio_embedding,
audio_embedding_lens=audio_embedding_lens,
)
else:
audio_embedding = cur_audio_embedding
audio_embedding_lens = cur_audio_embedding_lens
model_asr_inputs = {}
model_asr_inputs["inputs_embeds"] = inputs_asr_embeds
inputs_s2tt_embeds, contents, batch, source_ids, meta_data = model.inference_prepare(
@ -200,6 +253,8 @@ async def streaming_transcribe(websocket, audio_in, is_vad_end=False, his_state=
frontend,
device=device,
infer_with_assistant_input=True,
audio_embedding=audio_embedding,
audio_embedding_lens=audio_embedding_lens,
)
model_s2tt_inputs = {}
model_s2tt_inputs["inputs_embeds"] = inputs_s2tt_embeds
@ -324,6 +379,33 @@ async def streaming_transcribe(websocket, audio_in, is_vad_end=False, his_state=
)
if is_vad_end:
# concat_asr_text.append(onscreen_asr_res)
# concat_s2tt_text.append(onscreen_s2tt_res)
# concat_audio.append(cur_audio)
# concat_audio_embedding.append(cur_audio_embedding)
# concat_audio_embedding_lens.append(cur_audio_embedding_lens)
# websocket.streaming_state["concat_asr_text"] = concat_asr_text
# websocket.streaming_state["concat_s2tt_text"] = concat_s2tt_text
# websocket.streaming_state["concat_audio"] = concat_audio
# websocket.streaming_state["concat_audio_embedding"] = concat_audio_embedding
# websocket.streaming_state["concat_audio_embedding_lens"] = concat_audio_embedding_lens
clean_return_asr_res = return_asr_res.replace("<em>", "").replace("</em>", "")
clean_return_s2tt_res = return_s2tt_res.replace("<em>", "").replace("</em>", "")
if is_alpha_ending(clean_return_asr_res):
return_asr_res = clean_return_asr_res + ".<em></em>"
onscreen_asr_res += "."
elif is_chinese_ending(clean_return_asr_res):
return_asr_res = clean_return_asr_res + "。<em></em>"
onscreen_asr_res += ""
if is_alpha_ending(clean_return_s2tt_res):
return_s2tt_res = clean_return_s2tt_res + ".<em></em>"
onscreen_s2tt_res += "."
elif is_chinese_ending(clean_return_s2tt_res):
return_s2tt_res = clean_return_s2tt_res + "。<em></em>"
onscreen_s2tt_res += ""
message = json.dumps(
{
"mode": "online",
@ -385,7 +467,11 @@ async def ws_reset(websocket):
websocket.streaming_state["onscreen_s2tt_res"] = ""
websocket.streaming_state["previous_vad_onscreen_asr_text"] = ""
websocket.streaming_state["previous_vad_onscreen_s2tt_text"] = ""
# websocket.streaming_state["concat_asr_text"] = []
# websocket.streaming_state["concat_s2tt_text"] = []
# websocket.streaming_state["concat_audio"] = []
# websocket.streaming_state["concat_audio_embedding"] = []
# websocket.streaming_state["concat_audio_embedding_lens"] = []
websocket.status_dict_vad["cache"] = {}
websocket.status_dict_vad["is_final"] = True
@ -411,6 +497,11 @@ async def ws_serve(websocket, path):
"onscreen_s2tt_res": "",
"previous_vad_onscreen_asr_text": "",
"previous_vad_onscreen_s2tt_text": "",
# "concat_asr_text": [],
# "concat_s2tt_text": [],
# "concat_audio": [],
# "concat_audio_embedding": [],
# "concat_audio_embedding_lens": [],
"is_final": False,
}
websocket.status_dict_vad = {"cache": {}, "is_final": False}
@ -469,7 +560,7 @@ async def ws_serve(websocket, path):
# asr online
websocket.streaming_state["is_final"] = speech_end_i != -1
if (
len(frames_asr) % websocket.chunk_interval == 0
len(frames_asr) % DO_ASR_FRAME_INTERVAL == 0
or websocket.streaming_state["is_final"]
) and len(frames_asr) != 0:
audio_in = b"".join(frames_asr)
@ -480,6 +571,7 @@ async def ws_serve(websocket, path):
except Exception as e:
print(f"error in streaming, {e}")
print(f"error in streaming, {websocket.streaming_state}")
traceback.print_exc()
if speech_start:
frames_asr.append(message)
@ -487,8 +579,9 @@ async def ws_serve(websocket, path):
if not args.no_vad:
try:
speech_start_i, speech_end_i = await async_vad(websocket, message)
except:
print("error in vad")
except Exception as e:
print(f"error in vad, {e}")
traceback.print_exc()
if speech_start_i != -1:
speech_start = True
speech_end_i = -1
@ -513,6 +606,7 @@ async def ws_serve(websocket, path):
except Exception as e:
print(f"error in streaming, {e}")
print(f"error in streaming, {websocket.streaming_state}")
traceback.print_exc()
frames_asr = []
speech_start = False
websocket.streaming_state["previous_asr_text"] = ""
@ -573,6 +667,13 @@ async def ws_serve(websocket, path):
websocket.status_dict_vad["cache"] = {}
websocket.streaming_state["previous_asr_text"] = ""
websocket.streaming_state["previous_s2tt_text"] = ""
websocket.streaming_state["onscreen_asr_res"] = ""
websocket.streaming_state["onscreen_s2tt_res"] = ""
# websocket.streaming_state["concat_asr_text"] = []
# websocket.streaming_state["concat_s2tt_text"] = []
# websocket.streaming_state["concat_audio"] = []
# websocket.streaming_state["concat_audio_embedding"] = []
# websocket.streaming_state["concat_audio_embedding_lens"] = []
else:
frames = frames[-20:]
else:

View File

@ -12,14 +12,8 @@ from transformers import TextIteratorStreamer
from funasr import AutoModel
from modelscope.hub.api import HubApi
from modelscope.hub.snapshot_download import snapshot_download
UNFIX_LEN = 5
MIN_LEN_PER_PARAGRAPH = 25
MIN_SEC_AUDIO_FIX = 1.1
MAX_ITER_PER_CHUNK = 20
VAD_SEG_LOOKBACK_FRAME = 2
VAD_SEG_LOOKAHEAD_FRAME = 6
MAX_SEC_AUDIO_HISTORY = 40
import torch
import traceback
parser = argparse.ArgumentParser()
parser.add_argument(
@ -36,6 +30,8 @@ parser.add_argument("--vad_model_revision", type=str, default="master", help="")
parser.add_argument("--ngpu", type=int, default=1, help="0 for cpu, 1 for gpu")
parser.add_argument("--device", type=str, default="cuda", help="cuda, cpu")
parser.add_argument("--ncpu", type=int, default=4, help="cpu cores")
parser.add_argument("--return_sentence", action="store_true", help="return sentence or all_res")
parser.add_argument("--no_vad", action="store_true", help="infer without vad")
parser.add_argument(
"--certfile",
type=str,
@ -64,9 +60,9 @@ model_vad = AutoModel(
device=args.device,
disable_pbar=True,
disable_log=True,
# speech_noise_thres=0.3,
# max_single_segment_time=40000,
# max_end_silence_time=800,
speech_noise_thres=0.4,
max_single_segment_time=30000,
max_end_silence_time=800,
# chunk_size=60,
)
@ -85,7 +81,8 @@ audio_encoder_dir = snapshot_download("iic/SenseVoice", cache_dir=None, revision
# audio_encoder_dir = "/nfs/yangyexin.yyx/init_model/iic/SenseVoiceModelscope_0712"
device = "cuda:0"
all_file_paths = [
# "/nfs/yangyexin.yyx/init_model/s2tt/qwen2_7b_mmt_v15_20240902",
# "/nfs/yangyexin.yyx/init_model/s2tt/qwen2_7b_mmt_v15_20240910_streaming",
"FunAudioLLM/qwen2_7b_mmt_v15_20240910_streaming",
"FunAudioLLM/qwen2_7b_mmt_v15_20240902",
"FunAudioLLM/qwen2_7b_mmt_v14_20240830",
"FunAudioLLM/audiolm_v11_20240807",
@ -94,7 +91,11 @@ all_file_paths = [
"FunAudioLLM/Speech2Text_Align_V0628",
]
llm_kwargs = {"num_beams": 1, "do_sample": False}
llm_kwargs = {"num_beams": 1, "do_sample": False, "repetition_penalty": 1.3}
UNFIX_LEN = 5
MIN_LEN_PER_PARAGRAPH = 25
MIN_LEN_SEC_AUDIO_FIX = 1.1
DO_ASR_FRAME_INTERVAL = 12
ckpt_dir = all_file_paths[0]
@ -121,11 +122,6 @@ model_dict = {"model": model, "frontend": frontend, "tokenizer": tokenizer}
print("model loaded! only support one client at the same time now!!!!")
def remove_suffix(s, suffix):
    """Return ``s`` with one trailing occurrence of ``suffix`` removed, if present.

    Bug fix: the original slice ``s[:-len(suffix)]`` evaluated to ``s[:0]``
    (the empty string) when ``suffix == ""``, because ``str.endswith("")`` is
    always True and ``-0 == 0``.  An empty suffix now leaves ``s`` unchanged.
    """
    # Guard on a non-empty suffix so the negative slice is always well-formed.
    if suffix and s.endswith(suffix):
        return s[:-len(suffix)]
    return s
def load_bytes(input):
middle_data = np.frombuffer(input, dtype=np.int16)
middle_data = np.asarray(middle_data)
@ -142,8 +138,8 @@ def load_bytes(input):
return array
async def streaming_transcribe(websocket, audio_in, is_vad_end=False, his_state=None, asr_prompt=None, s2tt_prompt=None):
# current_time = datetime.now()
# print("DEBUG:" + str(current_time) + " call streaming_transcribe function:")
current_time = datetime.now()
print("DEBUG:" + str(current_time) + " call streaming_transcribe function:")
if his_state is None:
his_state = model_dict
model = his_state["model"]
@ -154,7 +150,11 @@ async def streaming_transcribe(websocket, audio_in, is_vad_end=False, his_state=
previous_s2tt_text = ""
previous_vad_onscreen_asr_text = ""
previous_vad_onscreen_s2tt_text = ""
previous_vad_audio = []
concat_asr_text = []
concat_s2tt_text = []
concat_audio = []
concat_audio_embedding = []
concat_audio_embedding_lens = []
else:
previous_asr_text = websocket.streaming_state.get("previous_asr_text", "")
previous_s2tt_text = websocket.streaming_state.get("previous_s2tt_text", "")
@ -164,40 +164,21 @@ async def streaming_transcribe(websocket, audio_in, is_vad_end=False, his_state=
previous_vad_onscreen_s2tt_text = websocket.streaming_state.get(
"previous_vad_onscreen_s2tt_text", ""
)
previous_vad_audio = websocket.streaming_state.get(
"previous_vad_audio", []
)
concat_asr_text = websocket.streaming_state.get("concat_asr_text", [])
concat_s2tt_text = websocket.streaming_state.get("concat_s2tt_text", [])
concat_audio = websocket.streaming_state.get("concat_audio", [])
concat_audio_embedding = websocket.streaming_state.get("concat_audio_embedding", [])
concat_audio_embedding_lens = websocket.streaming_state.get("concat_audio_embedding_lens", [])
if asr_prompt is None or asr_prompt == "":
asr_prompt = "Speech transcription:"
if s2tt_prompt is None or s2tt_prompt == "":
s2tt_prompt = "Translate into English:"
# audio_seconds = load_bytes(audio_in).shape[0] / 16000
audio_seconds = len(audio_in) // 32 / 1000
cur_audio = audio_in
print(f"Streaming audio length: {audio_seconds} seconds")
audio_in_prev_vad = b""
asr_text_prev_vad = ""
s2tt_text_prev_vad = ""
total_audio_seconds = audio_seconds
for i, audio_seg in enumerate(previous_vad_audio[::-1]):
cur_audio_seg_len = len(audio_seg) // 32 / 1000
if total_audio_seconds + cur_audio_seg_len <= MAX_SEC_AUDIO_HISTORY:
total_audio_seconds += cur_audio_seg_len
audio_in_prev_vad = b"".join([audio_seg, audio_in_prev_vad])
asr_text_prev_vad_seg = remove_suffix(previous_vad_onscreen_asr_text, "<vad>").split("<vad>")[-(i + 1)].replace("\n", "")
if asr_text_prev_vad_seg.endswith(".") or asr_text_prev_vad_seg.endswith(",") or asr_text_prev_vad_seg.endswith("?") or asr_text_prev_vad_seg.endswith("!"):
asr_text_prev_vad_seg = asr_text_prev_vad_seg + " "
s2tt_text_prev_vad_seg = remove_suffix(previous_vad_onscreen_s2tt_text, "<vad>").split("<vad>")[-(i + 1)].replace("\n", "")
if s2tt_text_prev_vad_seg.endswith(".") or s2tt_text_prev_vad_seg.endswith(",") or s2tt_text_prev_vad_seg.endswith("?") or s2tt_text_prev_vad_seg.endswith("!"):
s2tt_text_prev_vad_seg = s2tt_text_prev_vad_seg + " "
asr_text_prev_vad = asr_text_prev_vad_seg + asr_text_prev_vad
s2tt_text_prev_vad = s2tt_text_prev_vad_seg + s2tt_text_prev_vad
else:
websocket.streaming_state["previous_vad_audio"] = previous_vad_audio[-(i + 1):]
break
asr_content = []
system_prompt = "You are a helpful assistant."
asr_content.append({"role": "system", "content": system_prompt})
@ -205,12 +186,12 @@ async def streaming_transcribe(websocket, audio_in, is_vad_end=False, his_state=
system_prompt = "You are a helpful assistant."
s2tt_content.append({"role": "system", "content": system_prompt})
user_asr_prompt = f"{asr_prompt}<|startofspeech|>!!<|endofspeech|><|im_end|>\n<|im_start|>assistant\n{asr_text_prev_vad + previous_asr_text}"
user_s2tt_prompt = f"{s2tt_prompt}<|startofspeech|>!!<|endofspeech|><|im_end|>\n<|im_start|>assistant\n{s2tt_text_prev_vad + previous_s2tt_text}"
user_asr_prompt = f"{asr_prompt}<|startofspeech|>!!<|endofspeech|><|im_end|>\n<|im_start|>assistant\n{previous_asr_text}"
user_s2tt_prompt = f"{s2tt_prompt}<|startofspeech|>!!<|endofspeech|><|im_end|>\n<|im_start|>assistant\n{previous_s2tt_text}"
asr_content.append({"role": "user", "content": user_asr_prompt, "audio": b"".join([audio_in_prev_vad, audio_in])})
asr_content.append({"role": "user", "content": user_asr_prompt, "audio": audio_in})
asr_content.append({"role": "assistant", "content": "target_out"})
s2tt_content.append({"role": "user", "content": user_s2tt_prompt, "audio": b"".join([audio_in_prev_vad, audio_in])})
s2tt_content.append({"role": "user", "content": user_s2tt_prompt, "audio": audio_in})
s2tt_content.append({"role": "assistant", "content": "target_out"})
streaming_time_beg = time.time()
@ -223,6 +204,36 @@ async def streaming_transcribe(websocket, audio_in, is_vad_end=False, his_state=
device=device,
infer_with_assistant_input=True,
)
cur_audio_embedding, cur_audio_embedding_lens = meta_data["audio_adaptor_out"], meta_data["audio_adaptor_out_lens"]
if not args.return_sentence and len(concat_audio_embedding) != 0:
audio_embedding = torch.cat([concat_audio_embedding[-1], cur_audio_embedding], dim=1)
audio_embedding_lens = concat_audio_embedding_lens[-1] + cur_audio_embedding_lens
actual_prev_asr_text = concat_asr_text[-1] + previous_asr_text
actual_prev_s2tt_text = concat_s2tt_text[-1] + previous_s2tt_text
actual_audio = concat_audio[-1] + cur_audio
user_asr_prompt = f"{asr_prompt}<|startofspeech|>!!<|endofspeech|><|im_end|>\n<|im_start|>assistant\n{actual_prev_asr_text}"
user_s2tt_prompt = f"{s2tt_prompt}<|startofspeech|>!!<|endofspeech|><|im_end|>\n<|im_start|>assistant\n{actual_prev_s2tt_text}"
asr_content[1] = {"role": "user", "content": user_asr_prompt, "audio": actual_audio}
s2tt_content[1] = {"role": "user", "content": user_s2tt_prompt, "audio": actual_audio}
inputs_asr_embeds, contents, batch, source_ids, meta_data = model.inference_prepare(
[asr_content],
None,
"test_demo",
tokenizer,
frontend,
device=device,
infer_with_assistant_input=True,
audio_embedding=audio_embedding,
audio_embedding_lens=audio_embedding_lens,
)
else:
audio_embedding = cur_audio_embedding
audio_embedding_lens = cur_audio_embedding_lens
model_asr_inputs = {}
model_asr_inputs["inputs_embeds"] = inputs_asr_embeds
inputs_s2tt_embeds, contents, batch, source_ids, meta_data = model.inference_prepare(
@ -233,14 +244,14 @@ async def streaming_transcribe(websocket, audio_in, is_vad_end=False, his_state=
frontend,
device=device,
infer_with_assistant_input=True,
audio_embedding=audio_embedding,
audio_embedding_lens=audio_embedding_lens,
)
model_s2tt_inputs = {}
model_s2tt_inputs["inputs_embeds"] = inputs_s2tt_embeds
print("previous_asr_text:", previous_asr_text)
print("previous_s2tt_text:", previous_s2tt_text)
print("actual feed previous asr part:", asr_text_prev_vad + previous_asr_text)
print("actual feed previous s2tt part:", s2tt_text_prev_vad + previous_s2tt_text)
asr_streamer = TextIteratorStreamer(tokenizer)
asr_generation_kwargs = dict(model_asr_inputs, streamer=asr_streamer, max_new_tokens=1024)
@ -258,35 +269,26 @@ async def streaming_transcribe(websocket, audio_in, is_vad_end=False, his_state=
remain_s2tt_text = True
asr_iter_cnt = 0
s2tt_iter_cnt = 0
is_asr_repetition = False
is_s2tt_repetition = False
for new_asr_text in asr_streamer:
# current_time = datetime.now()
# print("DEBUG: " + str(current_time) + " " + f"generated new asr text {new_asr_text}")
current_time = datetime.now()
print("DEBUG: " + str(current_time) + " " + f"generated new asr text {new_asr_text}")
if len(new_asr_text) > 0:
onscreen_asr_res += new_asr_text.replace("<|im_end|>", "")
if len(new_asr_text.replace("<|im_end|>", "")) > 0:
asr_iter_cnt += 1
if asr_iter_cnt > MAX_ITER_PER_CHUNK:
is_asr_repetition = True
break
if remain_s2tt_text:
try:
new_s2tt_text = next(s2tt_streamer)
# current_time = datetime.now()
# print(
# "DEBUG: "
# + str(current_time)
# + " "
# + f"generated new s2tt text {new_s2tt_text}"
# )
s2tt_iter_cnt += 1
current_time = datetime.now()
print(
"DEBUG: "
+ str(current_time)
+ " "
+ f"generated new s2tt text {new_s2tt_text}"
)
if len(new_s2tt_text) > 0:
onscreen_s2tt_res += new_s2tt_text.replace("<|im_end|>", "")
except StopIteration:
new_s2tt_text = ""
remain_s2tt_text = False
@ -307,8 +309,6 @@ async def streaming_transcribe(websocket, audio_in, is_vad_end=False, his_state=
return_s2tt_res = fix_s2tt_part + "<em>" + unfix_s2tt_part + "</em>"
else:
return_s2tt_res = fix_s2tt_part + unfix_s2tt_part + "<em></em>"
return_asr_res = return_asr_res.replace("<vad>", "")
return_s2tt_res = return_s2tt_res.replace("<vad>", "")
message = json.dumps(
{
"mode": "online",
@ -316,6 +316,7 @@ async def streaming_transcribe(websocket, audio_in, is_vad_end=False, his_state=
"s2tt_text": return_s2tt_res,
"wav_name": websocket.wav_name,
"is_final": websocket.is_speaking,
"is_sentence_end": False,
}
)
await websocket.send(message)
@ -328,17 +329,12 @@ async def streaming_transcribe(websocket, audio_in, is_vad_end=False, his_state=
if remain_s2tt_text:
for new_s2tt_text in s2tt_streamer:
# current_time = datetime.now()
# print(
# "DEBUG: " + str(current_time) + " " + f"generated new s2tt text {new_s2tt_text}"
# )
current_time = datetime.now()
print(
"DEBUG: " + str(current_time) + " " + f"generated new s2tt text {new_s2tt_text}"
)
if len(new_s2tt_text) > 0:
onscreen_s2tt_res += new_s2tt_text.replace("<|im_end|>", "")
if len(new_s2tt_text.replace("<|im_end|>", "")) > 0:
s2tt_iter_cnt += 1
if s2tt_iter_cnt > MAX_ITER_PER_CHUNK:
is_s2tt_repetition = True
break
if len(new_s2tt_text) > 0:
all_asr_res = previous_vad_onscreen_asr_text + onscreen_asr_res
@ -355,8 +351,6 @@ async def streaming_transcribe(websocket, audio_in, is_vad_end=False, his_state=
return_s2tt_res = fix_s2tt_part + "<em>" + unfix_s2tt_part + "</em>"
else:
return_s2tt_res = fix_s2tt_part + unfix_s2tt_part + "<em></em>"
return_asr_res = return_asr_res.replace("<vad>", "")
return_s2tt_res = return_s2tt_res.replace("<vad>", "")
message = json.dumps(
{
"mode": "online",
@ -364,6 +358,7 @@ async def streaming_transcribe(websocket, audio_in, is_vad_end=False, his_state=
"s2tt_text": return_s2tt_res,
"wav_name": websocket.wav_name,
"is_final": websocket.is_speaking,
"is_sentence_end": False,
}
)
await websocket.send(message)
@ -374,36 +369,52 @@ async def streaming_transcribe(websocket, audio_in, is_vad_end=False, his_state=
previous_vad_onscreen_s2tt_text + onscreen_s2tt_res
)
if is_vad_end:
concat_asr_text.append(onscreen_asr_res)
concat_s2tt_text.append(onscreen_s2tt_res)
concat_audio.append(cur_audio)
concat_audio_embedding.append(cur_audio_embedding)
concat_audio_embedding_lens.append(cur_audio_embedding_lens)
websocket.streaming_state["concat_asr_text"] = concat_asr_text
websocket.streaming_state["concat_s2tt_text"] = concat_s2tt_text
websocket.streaming_state["concat_audio"] = concat_audio
websocket.streaming_state["concat_audio_embedding"] = concat_audio_embedding
websocket.streaming_state["concat_audio_embedding_lens"] = concat_audio_embedding_lens
message = json.dumps(
{
"mode": "online",
"asr_text": return_asr_res,
"s2tt_text": return_s2tt_res,
"wav_name": websocket.wav_name,
"is_final": websocket.is_speaking,
"is_sentence_end": True,
}
)
await websocket.send(message)
streaming_time_end = time.time()
print(f"Streaming inference time: {streaming_time_end - streaming_time_beg}")
asr_text_len = len(tokenizer.encode(onscreen_asr_res))
s2tt_text_len = len(tokenizer.encode(onscreen_s2tt_res))
if asr_text_len > UNFIX_LEN and audio_seconds > MIN_SEC_AUDIO_FIX and not is_asr_repetition:
if asr_text_len > UNFIX_LEN and audio_seconds > MIN_LEN_SEC_AUDIO_FIX:
pre_previous_asr_text = previous_asr_text
previous_asr_text = tokenizer.decode(
tokenizer.encode(onscreen_asr_res)[:-UNFIX_LEN]
).replace("<EFBFBD>", "")
if len(previous_asr_text) <= len(pre_previous_asr_text):
previous_asr_text = pre_previous_asr_text
elif is_asr_repetition:
pass
else:
previous_asr_text = ""
if (
s2tt_text_len > UNFIX_LEN
and audio_seconds > MIN_SEC_AUDIO_FIX
and not is_s2tt_repetition
):
if s2tt_text_len > UNFIX_LEN and audio_seconds > MIN_LEN_SEC_AUDIO_FIX:
pre_previous_s2tt_text = previous_s2tt_text
previous_s2tt_text = tokenizer.decode(
tokenizer.encode(onscreen_s2tt_res)[:-UNFIX_LEN]
).replace("<EFBFBD>", "")
if len(previous_s2tt_text) <= len(pre_previous_s2tt_text):
previous_s2tt_text = pre_previous_s2tt_text
elif is_s2tt_repetition:
pass
else:
previous_s2tt_text = ""
@ -431,8 +442,11 @@ async def ws_reset(websocket):
websocket.streaming_state["onscreen_s2tt_res"] = ""
websocket.streaming_state["previous_vad_onscreen_asr_text"] = ""
websocket.streaming_state["previous_vad_onscreen_s2tt_text"] = ""
websocket.streaming_state["previous_vad_audio"] = []
websocket.streaming_state["concat_asr_text"] = []
websocket.streaming_state["concat_s2tt_text"] = []
websocket.streaming_state["concat_audio"] = []
websocket.streaming_state["concat_audio_embedding"] = []
websocket.streaming_state["concat_audio_embedding_lens"] = []
websocket.status_dict_vad["cache"] = {}
websocket.status_dict_vad["is_final"] = True
@ -458,7 +472,11 @@ async def ws_serve(websocket, path):
"onscreen_s2tt_res": "",
"previous_vad_onscreen_asr_text": "",
"previous_vad_onscreen_s2tt_text": "",
"previous_vad_audio": [],
"concat_asr_text": [],
"concat_s2tt_text": [],
"concat_audio": [],
"concat_audio_embedding": [],
"concat_audio_embedding_lens": [],
"is_final": False,
}
websocket.status_dict_vad = {"cache": {}, "is_final": False}
@ -472,12 +490,12 @@ async def ws_serve(websocket, path):
try:
async for message in websocket:
# if isinstance(message, str):
# current_time = datetime.now()
# print("DEBUG:" + str(current_time) + " received message:", message)
# else:
# current_time = datetime.now()
# print("DEBUG:" + str(current_time) + " received audio bytes:")
if isinstance(message, str):
current_time = datetime.now()
print("DEBUG:" + str(current_time) + " received message:", message)
else:
current_time = datetime.now()
print("DEBUG:" + str(current_time) + " received audio bytes:")
if isinstance(message, str):
messagejson = json.loads(message)
@ -485,8 +503,8 @@ async def ws_serve(websocket, path):
if "is_speaking" in messagejson:
websocket.is_speaking = messagejson["is_speaking"]
websocket.streaming_state["is_final"] = not websocket.is_speaking
if not messagejson["is_speaking"]:
await clear_websocket()
# if not messagejson["is_speaking"]:
# await clear_websocket()
if "chunk_interval" in messagejson:
websocket.chunk_interval = messagejson["chunk_interval"]
if "wav_name" in messagejson:
@ -517,7 +535,7 @@ async def ws_serve(websocket, path):
# asr online
websocket.streaming_state["is_final"] = speech_end_i != -1
if (
len(frames_asr) % websocket.chunk_interval == 0
len(frames_asr) % DO_ASR_FRAME_INTERVAL == 0
or websocket.streaming_state["is_final"]
) and len(frames_asr) != 0:
audio_in = b"".join(frames_asr)
@ -528,80 +546,109 @@ async def ws_serve(websocket, path):
except Exception as e:
print(f"error in streaming, {e}")
print(f"error in streaming, {websocket.streaming_state}")
traceback.print_exc()
if speech_start:
frames_asr.append(message)
# vad online
try:
speech_start_i, speech_end_i = await async_vad(websocket, message)
except:
print("error in vad")
if speech_start_i != -1:
if not args.no_vad:
try:
speech_start_i, speech_end_i = await async_vad(websocket, message)
except Exception as e:
print(f"error in vad, {e}")
traceback.print_exc()
if speech_start_i != -1:
speech_start = True
speech_end_i = -1
beg_bias = (websocket.vad_pre_idx - speech_start_i) // duration_ms
frames_pre = frames[-beg_bias:]
frames_asr = []
frames_asr.extend(frames_pre)
else:
speech_start = True
beg_bias = (websocket.vad_pre_idx - speech_start_i) // duration_ms + VAD_SEG_LOOKBACK_FRAME
frames_pre = frames[-beg_bias:]
speech_end_i = -1
frames_asr = []
frames_asr.extend(frames_pre)
frames_asr.extend(frames)
# vad end
if speech_end_i != -1 or not websocket.is_speaking:
end_bias = max((websocket.vad_pre_idx - speech_end_i) // duration_ms - VAD_SEG_LOOKAHEAD_FRAME, 0)
frames_asr = frames_asr[:-end_bias]
audio_in = b"".join(frames_asr)
try:
await streaming_transcribe(
websocket, audio_in, is_vad_end=True, asr_prompt=asr_prompt, s2tt_prompt=s2tt_prompt
)
except Exception as e:
print(f"error in streaming, {e}")
print(f"error in streaming, {websocket.streaming_state}")
websocket.streaming_state["previous_vad_audio"] = websocket.streaming_state.get("previous_vad_audio", []) + [audio_in]
if speech_end_i != -1:
audio_in = b"".join(frames_asr)
try:
await streaming_transcribe(
websocket, audio_in, is_vad_end=True, asr_prompt=asr_prompt, s2tt_prompt=s2tt_prompt
)
except Exception as e:
print(f"error in streaming, {e}")
print(f"error in streaming, {websocket.streaming_state}")
traceback.print_exc()
frames_asr = []
speech_start = False
websocket.streaming_state["previous_asr_text"] = ""
websocket.streaming_state["previous_s2tt_text"] = ""
now_onscreen_asr_res = websocket.streaming_state.get("onscreen_asr_res", "")
now_onscreen_s2tt_res = websocket.streaming_state.get("onscreen_s2tt_res", "")
if (
len(tokenizer.encode(now_onscreen_asr_res.split("\n\n")[-1]))
< MIN_LEN_PER_PARAGRAPH
or len(tokenizer.encode(now_onscreen_s2tt_res.split("\n\n")[-1]))
< MIN_LEN_PER_PARAGRAPH
):
if (
now_onscreen_asr_res.endswith(".")
or now_onscreen_asr_res.endswith("?")
or now_onscreen_asr_res.endswith("!")
):
now_onscreen_asr_res += " "
if (
now_onscreen_s2tt_res.endswith(".")
or now_onscreen_s2tt_res.endswith("?")
or now_onscreen_s2tt_res.endswith("!")
):
now_onscreen_s2tt_res += " "
websocket.streaming_state["previous_vad_onscreen_asr_text"] = (
now_onscreen_asr_res + "<vad>"
)
websocket.streaming_state["previous_vad_onscreen_s2tt_text"] = (
now_onscreen_s2tt_res + "<vad>"
if not websocket.is_speaking:
message = json.dumps(
{
"mode": "online",
"asr_text": websocket.streaming_state["onscreen_asr_res"] + "<em></em>",
"s2tt_text": websocket.streaming_state["onscreen_s2tt_res"] + "<em></em>",
"wav_name": websocket.wav_name,
"is_final": websocket.is_speaking,
"is_sentence_end": True,
}
)
await websocket.send(message)
await clear_websocket()
if args.return_sentence:
websocket.streaming_state["previous_vad_onscreen_asr_text"] = ""
websocket.streaming_state["previous_vad_onscreen_s2tt_text"] = ""
else:
websocket.streaming_state["previous_vad_onscreen_asr_text"] = (
now_onscreen_asr_res + "<vad>" + "\n\n"
)
websocket.streaming_state["previous_vad_onscreen_s2tt_text"] = (
now_onscreen_s2tt_res + "<vad>" + "\n\n"
)
now_onscreen_asr_res = websocket.streaming_state.get("onscreen_asr_res", "")
now_onscreen_s2tt_res = websocket.streaming_state.get("onscreen_s2tt_res", "")
if (
len(tokenizer.encode(now_onscreen_asr_res.split("\n\n")[-1]))
< MIN_LEN_PER_PARAGRAPH
or len(tokenizer.encode(now_onscreen_s2tt_res.split("\n\n")[-1]))
< MIN_LEN_PER_PARAGRAPH
):
if (
now_onscreen_asr_res.endswith(".")
or now_onscreen_asr_res.endswith("?")
or now_onscreen_asr_res.endswith("!")
):
now_onscreen_asr_res += " "
if (
now_onscreen_s2tt_res.endswith(".")
or now_onscreen_s2tt_res.endswith("?")
or now_onscreen_s2tt_res.endswith("!")
):
now_onscreen_s2tt_res += " "
websocket.streaming_state["previous_vad_onscreen_asr_text"] = (
now_onscreen_asr_res
)
websocket.streaming_state["previous_vad_onscreen_s2tt_text"] = (
now_onscreen_s2tt_res
)
else:
websocket.streaming_state["previous_vad_onscreen_asr_text"] = (
now_onscreen_asr_res + "\n\n"
)
websocket.streaming_state["previous_vad_onscreen_s2tt_text"] = (
now_onscreen_s2tt_res + "\n\n"
)
if not websocket.is_speaking:
websocket.vad_pre_idx = 0
frames = []
websocket.status_dict_vad["cache"] = {}
websocket.streaming_state["previous_vad_audio"] = []
websocket.streaming_state["previous_asr_text"] = ""
websocket.streaming_state["previous_s2tt_text"] = ""
websocket.streaming_state["onscreen_asr_res"] = ""
websocket.streaming_state["onscreen_s2tt_res"] = ""
websocket.streaming_state["concat_asr_text"] = []
websocket.streaming_state["concat_s2tt_text"] = []
websocket.streaming_state["concat_audio"] = []
websocket.streaming_state["concat_audio_embedding"] = []
websocket.streaming_state["concat_audio_embedding_lens"] = []
else:
frames = frames[-20:]
else:
@ -617,8 +664,8 @@ async def ws_serve(websocket, path):
async def async_vad(websocket, audio_in):
# current_time = datetime.now()
# print("DEBUG:" + str(current_time) + " call vad function:")
current_time = datetime.now()
print("DEBUG:" + str(current_time) + " call vad function:")
segments_result = model_vad.generate(input=audio_in, **websocket.status_dict_vad)[0]["value"]
# print(segments_result)