diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py
index d66afb691..9765a4d91 100644
--- a/funasr/models/llm_asr/model.py
+++ b/funasr/models/llm_asr/model.py
@@ -972,11 +972,17 @@ class LLMASR4(nn.Module):
lora_init_param_path = lora_conf.get("init_param_path", None)
if lora_init_param_path is not None:
+ logging.info(f"lora_init_param_path: {lora_init_param_path}")
model = PeftModel.from_pretrained(model, lora_init_param_path)
+            if not lora_conf.get("freeze_lora", False):
+                for name, param in model.named_parameters():
+                    if "lora_" in name:
+                        param.requires_grad = True
else:
peft_config = LoraConfig(**lora_conf)
model = get_peft_model(model, peft_config)
- model.print_trainable_parameters()
+
+ model.print_trainable_parameters()
if llm_conf.get("activation_checkpoint", False):
model.gradient_checkpointing_enable()
diff --git a/funasr/train_utils/trainer_ds.py b/funasr/train_utils/trainer_ds.py
index 8630aa7ea..96bf0f315 100644
--- a/funasr/train_utils/trainer_ds.py
+++ b/funasr/train_utils/trainer_ds.py
@@ -225,6 +225,14 @@ class Trainer:
ckpt_name = f"ds-model.pt.ep{epoch}.{step}"
filename = os.path.join(self.output_dir, ckpt_name)
+ if self.use_lora:
+ lora_outdir = f"{self.output_dir}/lora-{ckpt_name}"
+ os.makedirs(lora_outdir, exist_ok=True)
+ if hasattr(model, "module"):
+ model.module.llm.save_pretrained(lora_outdir)
+ else:
+ model.llm.save_pretrained(lora_outdir)
+
# torch.save(state, filename)
with torch.no_grad():
model.save_checkpoint(save_dir=self.output_dir, tag=ckpt_name, client_state=state)
diff --git a/runtime/python/websocket/funasr_wss_client_llm.py b/runtime/python/websocket/funasr_wss_client_llm.py
index 2969dac6a..eb5c5d220 100644
--- a/runtime/python/websocket/funasr_wss_client_llm.py
+++ b/runtime/python/websocket/funasr_wss_client_llm.py
@@ -20,7 +20,7 @@ parser = argparse.ArgumentParser()
parser.add_argument(
"--host", type=str, default="localhost", required=False, help="host ip, localhost, 0.0.0.0"
)
-parser.add_argument("--port", type=int, default=10095, required=False, help="grpc server port")
+parser.add_argument("--port", type=int, default=10096, required=False, help="grpc server port")
parser.add_argument("--chunk_size", type=str, default="5, 10, 5", help="chunk")
parser.add_argument("--encoder_chunk_look_back", type=int, default=4, help="chunk")
parser.add_argument("--decoder_chunk_look_back", type=int, default=0, help="chunk")
@@ -42,7 +42,7 @@ parser.add_argument(
parser.add_argument("--thread_num", type=int, default=1, help="thread_num")
parser.add_argument("--words_max_print", type=int, default=10000, help="chunk")
parser.add_argument("--output_dir", type=str, default=None, help="output_dir")
-parser.add_argument("--ssl", type=int, default=1, help="1 for ssl connect, 0 for no ssl")
+parser.add_argument("--ssl", type=int, default=0, help="1 for ssl connect, 0 for no ssl")
parser.add_argument("--use_itn", type=int, default=1, help="1 for using itn, 0 for not itn")
parser.add_argument("--mode", type=str, default="2pass", help="offline, online, 2pass")
@@ -238,7 +238,7 @@ async def record_from_scp(chunk_begin, chunk_size):
while not offline_msg_done:
await asyncio.sleep(1)
- #await websocket.close()
+ # await websocket.close()
async def message(id):
@@ -255,8 +255,8 @@ async def message(id):
ibest_writer = None
try:
timestamp = int(time.time())
- file_name = f'tts_client_{timestamp}.pcm'
- file = open(file_name, 'wb')
+ file_name = f"tts_client_{timestamp}.pcm"
+ file = open(file_name, "wb")
while True:
meg = await websocket.recv()
if isinstance(meg, bytes):
@@ -306,7 +306,7 @@ async def message(id):
text_print_2pass_online = ""
text_print = text_print_2pass_offline + "{}".format(text)
text_print_2pass_offline += "{}".format(text)
- #text_print = text_print[-args.words_max_print :]
+ # text_print = text_print[-args.words_max_print :]
os.system("clear")
print("tts_count len: =>{}".format(tts_count))
print("\rpid" + str(id) + ": " + text_print)
@@ -316,6 +316,8 @@ async def message(id):
print("Exception:", e)
# traceback.print_exc()
# await websocket.close()
+
+
async def ws_client(id, chunk_begin, chunk_size):
if args.audio_in is None:
chunk_begin = 0
diff --git a/runtime/python/websocket/funasr_wss_server_llm.py b/runtime/python/websocket/funasr_wss_server_llm.py
index 6ad95672d..2c7fe6a57 100644
--- a/runtime/python/websocket/funasr_wss_server_llm.py
+++ b/runtime/python/websocket/funasr_wss_server_llm.py
@@ -11,8 +11,16 @@ import nls
from collections import deque
import threading
+
class NlsTtsSynthesizer:
- def __init__(self, websocket, tts_fifo, token, appkey, url="wss://nls-gateway-cn-beijing.aliyuncs.com/ws/v1"):
+ def __init__(
+ self,
+ websocket,
+ tts_fifo,
+ token,
+ appkey,
+ url="wss://nls-gateway-cn-beijing.aliyuncs.com/ws/v1",
+ ):
self.websocket = websocket
self.tts_fifo = tts_fifo
self.url = url
@@ -36,30 +44,33 @@ class NlsTtsSynthesizer:
on_completed=self.on_completed,
on_error=self.on_error,
on_close=self.on_close,
- callback_args=[]
+ callback_args=[],
)
+
def on_data(self, data, *args):
self.count += len(data)
+        print(f"count: {self.count}")
self.tts_fifo.append(data)
- #with open('tts_server.pcm', 'ab') as file:
+ # with open('tts_server.pcm', 'ab') as file:
# file.write(data)
+
def on_sentence_begin(self, message, *args):
- print('on sentence begin =>{}'.format(message))
+ print("on sentence begin =>{}".format(message))
def on_sentence_synthesis(self, message, *args):
- print('on sentence synthesis =>{}'.format(message))
+ print("on sentence synthesis =>{}".format(message))
def on_sentence_end(self, message, *args):
- print('on sentence end =>{}'.format(message))
+ print("on sentence end =>{}".format(message))
def on_completed(self, message, *args):
- print('on completed =>{}'.format(message))
+ print("on completed =>{}".format(message))
def on_error(self, message, *args):
- print('on_error args=>{}'.format(args))
+ print("on_error args=>{}".format(args))
def on_close(self, *args):
- print('on_close: args=>{}'.format(args))
+ print("on_close: args=>{}".format(args))
print("on message data cout: =>{}".format(self.count))
self.started = False
@@ -68,16 +79,18 @@ class NlsTtsSynthesizer:
self.started = True
def send_text(self, text):
+ print(f"text: {text}")
self.sdk.sendStreamInputTts(text)
async def stop(self):
self.sdk.stopStreamInputTts()
+
parser = argparse.ArgumentParser()
parser.add_argument(
- "--host", type=str, default="127.0.0.1", required=False, help="host ip, localhost, 0.0.0.0"
+ "--host", type=str, default="0.0.0.0", required=False, help="host ip, localhost, 0.0.0.0"
)
-parser.add_argument("--port", type=int, default=10095, required=False, help="grpc server port")
+parser.add_argument("--port", type=int, default=10096, required=False, help="grpc server port")
parser.add_argument(
"--asr_model",
type=str,
@@ -124,16 +137,6 @@ websocket_users = set()
print("model loading")
from funasr import AutoModel
-# # asr
-# model_asr = AutoModel(
-# model=args.asr_model,
-# model_revision=args.asr_model_revision,
-# ngpu=args.ngpu,
-# ncpu=args.ncpu,
-# device=args.device,
-# disable_pbar=False,
-# disable_log=True,
-# )
# vad
model_vad = AutoModel(
@@ -147,26 +150,6 @@ model_vad = AutoModel(
# chunk_size=60,
)
-# async def async_asr(websocket, audio_in):
-# if len(audio_in) > 0:
-# # print(len(audio_in))
-# print(type(audio_in))
-# rec_result = model_asr.generate(input=audio_in, **websocket.status_dict_asr)[0]
-# print("offline_asr, ", rec_result)
-#
-#
-# if len(rec_result["text"]) > 0:
-# # print("offline", rec_result)
-# mode = "2pass-offline" if "2pass" in websocket.mode else websocket.mode
-# message = json.dumps(
-# {
-# "mode": mode,
-# "text": rec_result["text"],
-# "wav_name": websocket.wav_name,
-# "is_final": websocket.is_speaking,
-# }
-# )
-# await websocket.send(message)
import os
@@ -205,20 +188,26 @@ if "key" in os.environ:
key = os.environ["key"]
api.login(key)
+appkey = "xxx"
+appkey_token = "xxx"
+if "appkey" in os.environ:
+ appkey = os.environ["appkey"]
+ appkey_token = os.environ["appkey_token"]
+
from modelscope.hub.snapshot_download import snapshot_download
-# os.environ["MODELSCOPE_CACHE"] = "/nfs/zhifu.gzf/modelscope"
-# llm_dir = snapshot_download('qwen/Qwen2-7B-Instruct', cache_dir=None, revision='master')
-# audio_encoder_dir = snapshot_download('iic/SenseVoice', cache_dir=None, revision='master')
+os.environ["MODELSCOPE_CACHE"] = "/mnt/workspace"
+llm_dir = snapshot_download("qwen/Qwen2-7B-Instruct", cache_dir=None, revision="master")
+audio_encoder_dir = snapshot_download("iic/SenseVoice", cache_dir=None, revision="master")
-llm_dir = "/cpfs_speech/zhifu.gzf/init_model/qwen/Qwen2-7B-Instruct"
-audio_encoder_dir = "/nfs/zhifu.gzf/init_model/SenseVoiceLargeModelscope"
+# llm_dir = "/cpfs_speech/zhifu.gzf/init_model/qwen/Qwen2-7B-Instruct"
+# audio_encoder_dir = "/nfs/zhifu.gzf/init_model/SenseVoiceLargeModelscope"
device = "cuda:0"
all_file_paths = [
- "/nfs/zhifu.gzf/init_model/Speech2Text_Align_V0712_modelscope"
- # "FunAudioLLM/Speech2Text_Align_V0712",
+ # "/nfs/zhifu.gzf/init_model/Speech2Text_Align_V0712_modelscope"
+ "FunAudioLLM/Speech2Text_Align_V0712",
# "FunAudioLLM/Speech2Text_Align_V0718",
# "FunAudioLLM/Speech2Text_Align_V0628",
]
@@ -246,6 +235,7 @@ tokenizer = model_llm.kwargs["tokenizer"]
model_dict = {"model": model, "frontend": frontend, "tokenizer": tokenizer}
+
async def send_to_client(websocket, syntheszier, tts_fifo):
# Sending tts data to the client
while True:
@@ -260,6 +250,8 @@ async def send_to_client(websocket, syntheszier, tts_fifo):
else:
print("WebSocket connection is not open or syntheszier is not started.")
break
+
+
async def model_inference(
websocket,
audio_in,
@@ -271,7 +263,9 @@ async def model_inference(
text_usr="",
):
fifo_queue = deque()
- synthesizer = NlsTtsSynthesizer(websocket=websocket, tts_fifo=fifo_queue, token="xxx", appkey="xxx")
+ synthesizer = NlsTtsSynthesizer(
+ websocket=websocket, tts_fifo=fifo_queue, token=appkey_token, appkey=appkey
+ )
synthesizer.start()
beg0 = time.time()
if his_state is None:
@@ -329,8 +323,10 @@ async def model_inference(
f"generated new text: {new_text}, time_fr_receive: {end_llm - beg0:.2f}, time_llm_decode: {end_llm - beg_llm:.2f}"
)
if len(new_text) > 0:
+ new_text = new_text.replace("<|im_end|>", "")
+ res += new_text
synthesizer.send_text(new_text)
- res += new_text.replace("<|im_end|>", "")
+
contents_i[-1]["content"] = res
websocket.llm_state["contents_i"] = contents_i
# history[-1][1] = res
@@ -341,7 +337,7 @@ async def model_inference(
"mode": mode,
"text": new_text,
"wav_name": websocket.wav_name,
- "is_final": websocket.is_speaking,
+ "is_final": False,
}
)
# print(f"online: {message}")
@@ -350,6 +346,7 @@ async def model_inference(
while len(fifo_queue) > 0:
await websocket.send(fifo_queue.popleft())
+ # synthesizer.send_text(res)
tts_to_client_task = asyncio.create_task(send_to_client(websocket, synthesizer, fifo_queue))
synthesizer.stop()
await tts_to_client_task
@@ -359,7 +356,7 @@ async def model_inference(
"mode": mode,
"text": res,
"wav_name": websocket.wav_name,
- "is_final": websocket.is_speaking,
+ "is_final": True,
}
)
# print(f"offline: {message}")
@@ -460,7 +457,8 @@ async def ws_serve(websocket, path):
frames_asr = []
frames_asr.extend(frames_pre)
# asr punc offline
- if speech_end_i != -1 or not websocket.is_speaking:
+ # if speech_end_i != -1 or not websocket.is_speaking:
+ if not websocket.is_speaking:
# print("vad end point")
if websocket.mode == "2pass" or websocket.mode == "offline":
audio_in = b"".join(frames_asr)
@@ -506,7 +504,7 @@ async def async_vad(websocket, audio_in):
return speech_start, speech_end
-if len(args.certfile) > 0:
+if len(args.certfile) > 0:
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
# Generate with Lets Encrypt, copied to this location, chown to current user and 400 permissions
diff --git a/runtime/python/websocket/funasr_wss_server_streaming_llm.py b/runtime/python/websocket/funasr_wss_server_streaming_llm.py
index 535a2828b..e61a8c3e3 100644
--- a/runtime/python/websocket/funasr_wss_server_streaming_llm.py
+++ b/runtime/python/websocket/funasr_wss_server_streaming_llm.py
@@ -3,6 +3,7 @@ import asyncio
import json
import websockets
import time
+from datetime import datetime
import argparse
import ssl
import numpy as np
@@ -14,7 +15,7 @@ from modelscope.hub.snapshot_download import snapshot_download
parser = argparse.ArgumentParser()
parser.add_argument(
- "--host", type=str, default="127.0.0.1", required=False, help="host ip, localhost, 0.0.0.0"
+ "--host", type=str, default="0.0.0.0", required=False, help="host ip, localhost, 0.0.0.0"
)
parser.add_argument("--port", type=int, default=10095, required=False, help="grpc server port")
parser.add_argument(
@@ -30,14 +31,14 @@ parser.add_argument("--ncpu", type=int, default=4, help="cpu cores")
parser.add_argument(
"--certfile",
type=str,
- default="../../ssl_key/server.crt",
+ default="",
required=False,
help="certfile for ssl",
)
parser.add_argument(
"--keyfile",
type=str,
- default="../../ssl_key/server.key",
+ default="ssl_key/server.key",
required=False,
help="keyfile for ssl",
)
@@ -61,20 +62,25 @@ model_vad = AutoModel(
)
api = HubApi()
+key = os.environ.get("key", "")
+if key:
+    api.login(key)
if "key" in os.environ:
    key = os.environ["key"]
-    api.login(key)
# os.environ["MODELSCOPE_CACHE"] = "/nfs/zhifu.gzf/modelscope"
-# llm_dir = snapshot_download('qwen/Qwen2-7B-Instruct', cache_dir=None, revision='master')
-# audio_encoder_dir = snapshot_download('iic/SenseVoice', cache_dir=None, revision='master')
+llm_dir = snapshot_download("qwen/Qwen2-7B-Instruct", cache_dir=None, revision="master")
+audio_encoder_dir = snapshot_download("iic/SenseVoice", cache_dir=None, revision="master")
-llm_dir = "/cpfs_speech/zhifu.gzf/init_model/qwen/Qwen2-7B-Instruct"
-audio_encoder_dir = "/nfs/yangyexin.yyx/init_model/iic/SenseVoiceModelscope_0712"
+# llm_dir = "/cpfs_speech/zhifu.gzf/init_model/qwen/Qwen2-7B-Instruct"
+# audio_encoder_dir = "/nfs/yangyexin.yyx/init_model/iic/SenseVoiceModelscope_0712"
device = "cuda:0"
all_file_paths = [
- "/nfs/yangyexin.yyx/init_model/qwen2_7b_mmt_v14_20240830/",
- "/nfs/yangyexin.yyx/init_model/audiolm_v14_20240824_train_encoder_all_20240822_lr1e-4_warmup2350/"
+ "FunAudioLLM/qwen2_7b_mmt_v14_20240830",
+ "FunAudioLLM/audiolm_v11_20240807",
+ "FunAudioLLM/Speech2Text_Align_V0712",
+ "FunAudioLLM/Speech2Text_Align_V0718",
+ "FunAudioLLM/Speech2Text_Align_V0628",
]
llm_kwargs = {"num_beams": 1, "do_sample": False}
@@ -85,6 +91,7 @@ MAX_ITER_PER_CHUNK = 20
ckpt_dir = all_file_paths[0]
+
def contains_lora_folder(directory):
for name in os.listdir(directory):
full_path = os.path.join(directory, name)
@@ -92,7 +99,10 @@ def contains_lora_folder(directory):
return full_path
return None
+
+ckpt_dir = snapshot_download(ckpt_dir, cache_dir=None, revision="master")
lora_folder = contains_lora_folder(ckpt_dir)
+
if lora_folder is not None:
model_llm = AutoModel(
model=ckpt_dir,
@@ -102,7 +112,11 @@ if lora_folder is not None:
llm_dtype="bf16",
max_length=1024,
llm_kwargs=llm_kwargs,
- llm_conf={"init_param_path": llm_dir, "lora_conf": {"init_param_path": lora_folder}, "load_kwargs": {"attn_implementation": "eager"}},
+ llm_conf={
+ "init_param_path": llm_dir,
+ "lora_conf": {"init_param_path": lora_folder},
+ "load_kwargs": {"attn_implementation": "eager"},
+ },
tokenizer_conf={"init_param_path": llm_dir},
audio_encoder=audio_encoder_dir,
)
@@ -128,6 +142,7 @@ model_dict = {"model": model, "frontend": frontend, "tokenizer": tokenizer}
print("model loaded! only support one client at the same time now!!!!")
+
def load_bytes(input):
middle_data = np.frombuffer(input, dtype=np.int16)
middle_data = np.asarray(middle_data)
@@ -143,7 +158,12 @@ def load_bytes(input):
array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
return array
-async def streaming_transcribe(websocket, audio_in, his_state=None, asr_prompt=None, s2tt_prompt=None):
+
+async def streaming_transcribe(
+ websocket, audio_in, his_state=None, asr_prompt=None, s2tt_prompt=None
+):
+ current_time = datetime.now()
+ print("DEBUG:" + str(current_time) + " call streaming_transcribe function:")
if his_state is None:
his_state = model_dict
model = his_state["model"]
@@ -157,17 +177,21 @@ async def streaming_transcribe(websocket, audio_in, his_state=None, asr_prompt=N
else:
previous_asr_text = websocket.streaming_state.get("previous_asr_text", "")
previous_s2tt_text = websocket.streaming_state.get("previous_s2tt_text", "")
- previous_vad_onscreen_asr_text = websocket.streaming_state.get("previous_vad_onscreen_asr_text", "")
- previous_vad_onscreen_s2tt_text = websocket.streaming_state.get("previous_vad_onscreen_s2tt_text", "")
-
+ previous_vad_onscreen_asr_text = websocket.streaming_state.get(
+ "previous_vad_onscreen_asr_text", ""
+ )
+ previous_vad_onscreen_s2tt_text = websocket.streaming_state.get(
+ "previous_vad_onscreen_s2tt_text", ""
+ )
+
if asr_prompt is None or asr_prompt == "":
- asr_prompt = "Copy:"
+ asr_prompt = "Speech transcription:"
if s2tt_prompt is None or s2tt_prompt == "":
- s2tt_prompt = "Translate the following sentence into English:"
+ s2tt_prompt = "Translate into English:"
audio_seconds = load_bytes(audio_in).shape[0] / 16000
print(f"Streaming audio length: {audio_seconds} seconds")
-
+
asr_content = []
system_prompt = "You are a helpful assistant."
asr_content.append({"role": "system", "content": system_prompt})
@@ -185,12 +209,24 @@ async def streaming_transcribe(websocket, audio_in, his_state=None, asr_prompt=N
streaming_time_beg = time.time()
inputs_asr_embeds, contents, batch, source_ids, meta_data = model.inference_prepare(
- [asr_content], None, "test_demo", tokenizer, frontend, device=device, infer_with_assistant_input=True
+ [asr_content],
+ None,
+ "test_demo",
+ tokenizer,
+ frontend,
+ device=device,
+ infer_with_assistant_input=True,
)
model_asr_inputs = {}
model_asr_inputs["inputs_embeds"] = inputs_asr_embeds
inputs_s2tt_embeds, contents, batch, source_ids, meta_data = model.inference_prepare(
- [s2tt_content], None, "test_demo", tokenizer, frontend, device=device, infer_with_assistant_input=True
+ [s2tt_content],
+ None,
+ "test_demo",
+ tokenizer,
+ frontend,
+ device=device,
+ infer_with_assistant_input=True,
)
model_s2tt_inputs = {}
model_s2tt_inputs["inputs_embeds"] = inputs_s2tt_embeds
@@ -208,7 +244,7 @@ async def streaming_transcribe(websocket, audio_in, his_state=None, asr_prompt=N
s2tt_generation_kwargs.update(llm_kwargs)
s2tt_thread = Thread(target=model.llm.generate, kwargs=s2tt_generation_kwargs)
s2tt_thread.start()
-
+
onscreen_asr_res = previous_asr_text
onscreen_s2tt_res = previous_s2tt_text
@@ -220,8 +256,8 @@ async def streaming_transcribe(websocket, audio_in, his_state=None, asr_prompt=N
is_s2tt_repetition = False
for new_asr_text in asr_streamer:
- print(f"generated new asr text: {new_asr_text}")
-
+ current_time = datetime.now()
+ print("DEBUG: " + str(current_time) + " " + f"generated new asr text: {new_asr_text}")
if len(new_asr_text) > 0:
onscreen_asr_res += new_asr_text.replace("<|im_end|>", "")
if len(new_asr_text.replace("<|im_end|>", "")) > 0:
@@ -233,7 +269,13 @@ async def streaming_transcribe(websocket, audio_in, his_state=None, asr_prompt=N
if remain_s2tt_text:
try:
new_s2tt_text = next(s2tt_streamer)
- print(f"generated new s2tt text: {new_s2tt_text}")
+ current_time = datetime.now()
+ print(
+ "DEBUG: "
+ + str(current_time)
+ + " "
+ + f"generated new s2tt text: {new_s2tt_text}"
+ )
s2tt_iter_cnt += 1
if len(new_s2tt_text) > 0:
onscreen_s2tt_res += new_s2tt_text.replace("<|im_end|>", "")
@@ -241,16 +283,16 @@ async def streaming_transcribe(websocket, audio_in, his_state=None, asr_prompt=N
new_s2tt_text = ""
remain_s2tt_text = False
pass
-
+
if len(new_asr_text) > 0 or len(new_s2tt_text) > 0:
all_asr_res = previous_vad_onscreen_asr_text + onscreen_asr_res
fix_asr_part = previous_vad_onscreen_asr_text + previous_asr_text
- unfix_asr_part = all_asr_res[len(fix_asr_part):]
- return_asr_res = fix_asr_part + ""+ unfix_asr_part + ""
+ unfix_asr_part = all_asr_res[len(fix_asr_part) :]
+ return_asr_res = fix_asr_part + "" + unfix_asr_part + ""
all_s2tt_res = previous_vad_onscreen_s2tt_text + onscreen_s2tt_res
fix_s2tt_part = previous_vad_onscreen_s2tt_text + previous_s2tt_text
- unfix_s2tt_part = all_s2tt_res[len(fix_s2tt_part):]
- return_s2tt_res = fix_s2tt_part + ""+ unfix_s2tt_part + ""
+ unfix_s2tt_part = all_s2tt_res[len(fix_s2tt_part) :]
+ return_s2tt_res = fix_s2tt_part + "" + unfix_s2tt_part + ""
message = json.dumps(
{
"mode": "online",
@@ -261,14 +303,19 @@ async def streaming_transcribe(websocket, audio_in, his_state=None, asr_prompt=N
}
)
await websocket.send(message)
- websocket.streaming_state["onscreen_asr_res"] = previous_vad_onscreen_asr_text + onscreen_asr_res
- websocket.streaming_state["onscreen_s2tt_res"] = previous_vad_onscreen_s2tt_text + onscreen_s2tt_res
+ websocket.streaming_state["onscreen_asr_res"] = (
+ previous_vad_onscreen_asr_text + onscreen_asr_res
+ )
+ websocket.streaming_state["onscreen_s2tt_res"] = (
+ previous_vad_onscreen_s2tt_text + onscreen_s2tt_res
+ )
-
if remain_s2tt_text:
for new_s2tt_text in s2tt_streamer:
- print(f"generated new s2tt text: {new_s2tt_text}")
-
+ current_time = datetime.now()
+ print(
+ "DEBUG: " + str(current_time) + " " + f"generated new s2tt text: {new_s2tt_text}"
+ )
if len(new_s2tt_text) > 0:
onscreen_s2tt_res += new_s2tt_text.replace("<|im_end|>", "")
if len(new_s2tt_text.replace("<|im_end|>", "")) > 0:
@@ -276,16 +323,16 @@ async def streaming_transcribe(websocket, audio_in, his_state=None, asr_prompt=N
if s2tt_iter_cnt > MAX_ITER_PER_CHUNK:
is_s2tt_repetition = True
break
-
+
if len(new_s2tt_text) > 0:
all_asr_res = previous_vad_onscreen_asr_text + onscreen_asr_res
fix_asr_part = previous_vad_onscreen_asr_text + previous_asr_text
- unfix_asr_part = all_asr_res[len(fix_asr_part):]
- return_asr_res = fix_asr_part + ""+ unfix_asr_part + ""
+ unfix_asr_part = all_asr_res[len(fix_asr_part) :]
+ return_asr_res = fix_asr_part + "" + unfix_asr_part + ""
all_s2tt_res = previous_vad_onscreen_s2tt_text + onscreen_s2tt_res
fix_s2tt_part = previous_vad_onscreen_s2tt_text + previous_s2tt_text
- unfix_s2tt_part = all_s2tt_res[len(fix_s2tt_part):]
- return_s2tt_res = fix_s2tt_part + ""+ unfix_s2tt_part + ""
+ unfix_s2tt_part = all_s2tt_res[len(fix_s2tt_part) :]
+ return_s2tt_res = fix_s2tt_part + "" + unfix_s2tt_part + ""
message = json.dumps(
{
"mode": "online",
@@ -296,8 +343,12 @@ async def streaming_transcribe(websocket, audio_in, his_state=None, asr_prompt=N
}
)
await websocket.send(message)
- websocket.streaming_state["onscreen_asr_res"] = previous_vad_onscreen_asr_text + onscreen_asr_res
- websocket.streaming_state["onscreen_s2tt_res"] = previous_vad_onscreen_s2tt_text + onscreen_s2tt_res
+ websocket.streaming_state["onscreen_asr_res"] = (
+ previous_vad_onscreen_asr_text + onscreen_asr_res
+ )
+ websocket.streaming_state["onscreen_s2tt_res"] = (
+ previous_vad_onscreen_s2tt_text + onscreen_s2tt_res
+ )
streaming_time_end = time.time()
print(f"Streaming inference time: {streaming_time_end - streaming_time_beg}")
@@ -307,16 +358,24 @@ async def streaming_transcribe(websocket, audio_in, his_state=None, asr_prompt=N
if asr_text_len > UNFIX_LEN and audio_seconds > MIN_LEN_SEC_AUDIO_FIX and not is_asr_repetition:
pre_previous_asr_text = previous_asr_text
- previous_asr_text = tokenizer.decode(tokenizer.encode(onscreen_asr_res)[:-UNFIX_LEN]).replace("�", "")
+ previous_asr_text = tokenizer.decode(
+ tokenizer.encode(onscreen_asr_res)[:-UNFIX_LEN]
+ ).replace("�", "")
if len(previous_asr_text) <= len(pre_previous_asr_text):
previous_asr_text = pre_previous_asr_text
elif is_asr_repetition:
pass
else:
previous_asr_text = ""
- if s2tt_text_len > UNFIX_LEN and audio_seconds > MIN_LEN_SEC_AUDIO_FIX and not is_s2tt_repetition:
+ if (
+ s2tt_text_len > UNFIX_LEN
+ and audio_seconds > MIN_LEN_SEC_AUDIO_FIX
+ and not is_s2tt_repetition
+ ):
pre_previous_s2tt_text = previous_s2tt_text
- previous_s2tt_text = tokenizer.decode(tokenizer.encode(onscreen_s2tt_res)[:-UNFIX_LEN]).replace("�", "")
+ previous_s2tt_text = tokenizer.decode(
+ tokenizer.encode(onscreen_s2tt_res)[:-UNFIX_LEN]
+ ).replace("�", "")
if len(previous_s2tt_text) <= len(pre_previous_s2tt_text):
previous_s2tt_text = pre_previous_s2tt_text
elif is_s2tt_repetition:
@@ -325,9 +384,13 @@ async def streaming_transcribe(websocket, audio_in, his_state=None, asr_prompt=N
previous_s2tt_text = ""
websocket.streaming_state["previous_asr_text"] = previous_asr_text
- websocket.streaming_state["onscreen_asr_res"] = previous_vad_onscreen_asr_text + onscreen_asr_res
+ websocket.streaming_state["onscreen_asr_res"] = (
+ previous_vad_onscreen_asr_text + onscreen_asr_res
+ )
websocket.streaming_state["previous_s2tt_text"] = previous_s2tt_text
- websocket.streaming_state["onscreen_s2tt_res"] = previous_vad_onscreen_s2tt_text + onscreen_s2tt_res
+ websocket.streaming_state["onscreen_s2tt_res"] = (
+ previous_vad_onscreen_s2tt_text + onscreen_s2tt_res
+ )
print("fix asr part:", previous_asr_text)
print("fix s2tt part:", previous_s2tt_text)
@@ -383,6 +446,13 @@ async def ws_serve(websocket, path):
try:
async for message in websocket:
+ if isinstance(message, str):
+ current_time = datetime.now()
+ print("DEBUG:" + str(current_time) + " received message:", message)
+ else:
+ current_time = datetime.now()
+ print("DEBUG:" + str(current_time) + " received audio bytes:")
+
if isinstance(message, str):
messagejson = json.loads(message)
@@ -403,11 +473,11 @@ async def ws_serve(websocket, path):
if "asr_prompt" in messagejson:
asr_prompt = messagejson["asr_prompt"]
else:
- asr_prompt = "Copy:"
+ asr_prompt = "Speech transcription:"
if "s2tt_prompt" in messagejson:
s2tt_prompt = messagejson["s2tt_prompt"]
else:
- s2tt_prompt = "Translate the following sentence into English:"
+ s2tt_prompt = "Translate into English:"
websocket.status_dict_vad["chunk_size"] = int(
chunk_size[1] * 60 / websocket.chunk_interval
@@ -421,19 +491,20 @@ async def ws_serve(websocket, path):
# asr online
websocket.streaming_state["is_final"] = speech_end_i != -1
if (
- (len(frames_asr) % websocket.chunk_interval == 0
- or websocket.streaming_state["is_final"])
- and len(frames_asr) != 0
- ):
+ len(frames_asr) % websocket.chunk_interval == 0
+ or websocket.streaming_state["is_final"]
+ ) and len(frames_asr) != 0:
audio_in = b"".join(frames_asr)
try:
- await streaming_transcribe(websocket, audio_in, asr_prompt=asr_prompt, s2tt_prompt=s2tt_prompt)
+ await streaming_transcribe(
+ websocket, audio_in, asr_prompt=asr_prompt, s2tt_prompt=s2tt_prompt
+ )
except Exception as e:
print(f"error in streaming, {e}")
print(f"error in streaming, {websocket.streaming_state}")
if speech_start:
frames_asr.append(message)
-
+
# vad online
try:
speech_start_i, speech_end_i = await async_vad(websocket, message)
@@ -450,7 +521,9 @@ async def ws_serve(websocket, path):
if speech_end_i != -1 or not websocket.is_speaking:
audio_in = b"".join(frames_asr)
try:
- await streaming_transcribe(websocket, audio_in, asr_prompt=asr_prompt, s2tt_prompt=s2tt_prompt)
+ await streaming_transcribe(
+ websocket, audio_in, asr_prompt=asr_prompt, s2tt_prompt=s2tt_prompt
+ )
except Exception as e:
print(f"error in streaming, {e}")
print(f"error in streaming, {websocket.streaming_state}")
@@ -460,16 +533,37 @@ async def ws_serve(websocket, path):
websocket.streaming_state["previous_s2tt_text"] = ""
now_onscreen_asr_res = websocket.streaming_state.get("onscreen_asr_res", "")
now_onscreen_s2tt_res = websocket.streaming_state.get("onscreen_s2tt_res", "")
- if len(tokenizer.encode(now_onscreen_asr_res.split("\n\n")[-1])) < MIN_LEN_PER_PARAGRAPH or len(tokenizer.encode(now_onscreen_s2tt_res.split("\n\n")[-1])) < MIN_LEN_PER_PARAGRAPH:
- if now_onscreen_asr_res.endswith(".") or now_onscreen_asr_res.endswith("?") or now_onscreen_asr_res.endswith("!"):
+ if (
+ len(tokenizer.encode(now_onscreen_asr_res.split("\n\n")[-1]))
+ < MIN_LEN_PER_PARAGRAPH
+ or len(tokenizer.encode(now_onscreen_s2tt_res.split("\n\n")[-1]))
+ < MIN_LEN_PER_PARAGRAPH
+ ):
+ if (
+ now_onscreen_asr_res.endswith(".")
+ or now_onscreen_asr_res.endswith("?")
+ or now_onscreen_asr_res.endswith("!")
+ ):
now_onscreen_asr_res += " "
- if now_onscreen_s2tt_res.endswith(".") or now_onscreen_s2tt_res.endswith("?") or now_onscreen_s2tt_res.endswith("!"):
+ if (
+ now_onscreen_s2tt_res.endswith(".")
+ or now_onscreen_s2tt_res.endswith("?")
+ or now_onscreen_s2tt_res.endswith("!")
+ ):
now_onscreen_s2tt_res += " "
- websocket.streaming_state["previous_vad_onscreen_asr_text"] = now_onscreen_asr_res
- websocket.streaming_state["previous_vad_onscreen_s2tt_text"] = now_onscreen_s2tt_res
+ websocket.streaming_state["previous_vad_onscreen_asr_text"] = (
+ now_onscreen_asr_res
+ )
+ websocket.streaming_state["previous_vad_onscreen_s2tt_text"] = (
+ now_onscreen_s2tt_res
+ )
else:
- websocket.streaming_state["previous_vad_onscreen_asr_text"] = now_onscreen_asr_res + "\n\n"
- websocket.streaming_state["previous_vad_onscreen_s2tt_text"] = now_onscreen_s2tt_res + "\n\n"
+ websocket.streaming_state["previous_vad_onscreen_asr_text"] = (
+ now_onscreen_asr_res + "\n\n"
+ )
+ websocket.streaming_state["previous_vad_onscreen_s2tt_text"] = (
+ now_onscreen_s2tt_res + "\n\n"
+ )
if not websocket.is_speaking:
websocket.vad_pre_idx = 0
frames = []
@@ -491,6 +585,8 @@ async def ws_serve(websocket, path):
async def async_vad(websocket, audio_in):
+ current_time = datetime.now()
+ print("DEBUG:" + str(current_time) + " call vad function:")
segments_result = model_vad.generate(input=audio_in, **websocket.status_dict_vad)[0]["value"]
# print(segments_result)
@@ -506,7 +602,7 @@ async def async_vad(websocket, audio_in):
return speech_start, speech_end
-if len(args.certfile) > 0:
+if len(args.certfile) > 0:
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
# Generate with Lets Encrypt, copied to this location, chown to current user and 400 permissions