diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py index d66afb691..9765a4d91 100644 --- a/funasr/models/llm_asr/model.py +++ b/funasr/models/llm_asr/model.py @@ -972,11 +972,17 @@ class LLMASR4(nn.Module): lora_init_param_path = lora_conf.get("init_param_path", None) if lora_init_param_path is not None: + logging.info(f"lora_init_param_path: {lora_init_param_path}") model = PeftModel.from_pretrained(model, lora_init_param_path) + for name, param in model.named_parameters(): + if not lora_conf.get("freeze_lora", False): + if "lora_" in name: + param.requires_grad = True else: peft_config = LoraConfig(**lora_conf) model = get_peft_model(model, peft_config) - model.print_trainable_parameters() + + model.print_trainable_parameters() if llm_conf.get("activation_checkpoint", False): model.gradient_checkpointing_enable() diff --git a/funasr/train_utils/trainer_ds.py b/funasr/train_utils/trainer_ds.py index 8630aa7ea..96bf0f315 100644 --- a/funasr/train_utils/trainer_ds.py +++ b/funasr/train_utils/trainer_ds.py @@ -225,6 +225,14 @@ class Trainer: ckpt_name = f"ds-model.pt.ep{epoch}.{step}" filename = os.path.join(self.output_dir, ckpt_name) + if self.use_lora: + lora_outdir = f"{self.output_dir}/lora-{ckpt_name}" + os.makedirs(lora_outdir, exist_ok=True) + if hasattr(model, "module"): + model.module.llm.save_pretrained(lora_outdir) + else: + model.llm.save_pretrained(lora_outdir) + # torch.save(state, filename) with torch.no_grad(): model.save_checkpoint(save_dir=self.output_dir, tag=ckpt_name, client_state=state) diff --git a/runtime/python/websocket/funasr_wss_client_llm.py b/runtime/python/websocket/funasr_wss_client_llm.py index 2969dac6a..eb5c5d220 100644 --- a/runtime/python/websocket/funasr_wss_client_llm.py +++ b/runtime/python/websocket/funasr_wss_client_llm.py @@ -20,7 +20,7 @@ parser = argparse.ArgumentParser() parser.add_argument( "--host", type=str, default="localhost", required=False, help="host ip, localhost, 0.0.0.0" ) 
-parser.add_argument("--port", type=int, default=10095, required=False, help="grpc server port") +parser.add_argument("--port", type=int, default=10096, required=False, help="grpc server port") parser.add_argument("--chunk_size", type=str, default="5, 10, 5", help="chunk") parser.add_argument("--encoder_chunk_look_back", type=int, default=4, help="chunk") parser.add_argument("--decoder_chunk_look_back", type=int, default=0, help="chunk") @@ -42,7 +42,7 @@ parser.add_argument( parser.add_argument("--thread_num", type=int, default=1, help="thread_num") parser.add_argument("--words_max_print", type=int, default=10000, help="chunk") parser.add_argument("--output_dir", type=str, default=None, help="output_dir") -parser.add_argument("--ssl", type=int, default=1, help="1 for ssl connect, 0 for no ssl") +parser.add_argument("--ssl", type=int, default=0, help="1 for ssl connect, 0 for no ssl") parser.add_argument("--use_itn", type=int, default=1, help="1 for using itn, 0 for not itn") parser.add_argument("--mode", type=str, default="2pass", help="offline, online, 2pass") @@ -238,7 +238,7 @@ async def record_from_scp(chunk_begin, chunk_size): while not offline_msg_done: await asyncio.sleep(1) - #await websocket.close() + # await websocket.close() async def message(id): @@ -255,8 +255,8 @@ async def message(id): ibest_writer = None try: timestamp = int(time.time()) - file_name = f'tts_client_{timestamp}.pcm' - file = open(file_name, 'wb') + file_name = f"tts_client_{timestamp}.pcm" + file = open(file_name, "wb") while True: meg = await websocket.recv() if isinstance(meg, bytes): @@ -306,7 +306,7 @@ async def message(id): text_print_2pass_online = "" text_print = text_print_2pass_offline + "{}".format(text) text_print_2pass_offline += "{}".format(text) - #text_print = text_print[-args.words_max_print :] + # text_print = text_print[-args.words_max_print :] os.system("clear") print("tts_count len: =>{}".format(tts_count)) print("\rpid" + str(id) + ": " + text_print) @@ -316,6 
+316,8 @@ async def message(id): print("Exception:", e) # traceback.print_exc() # await websocket.close() + + async def ws_client(id, chunk_begin, chunk_size): if args.audio_in is None: chunk_begin = 0 diff --git a/runtime/python/websocket/funasr_wss_server_llm.py b/runtime/python/websocket/funasr_wss_server_llm.py index 6ad95672d..2c7fe6a57 100644 --- a/runtime/python/websocket/funasr_wss_server_llm.py +++ b/runtime/python/websocket/funasr_wss_server_llm.py @@ -11,8 +11,16 @@ import nls from collections import deque import threading + class NlsTtsSynthesizer: - def __init__(self, websocket, tts_fifo, token, appkey, url="wss://nls-gateway-cn-beijing.aliyuncs.com/ws/v1"): + def __init__( + self, + websocket, + tts_fifo, + token, + appkey, + url="wss://nls-gateway-cn-beijing.aliyuncs.com/ws/v1", + ): self.websocket = websocket self.tts_fifo = tts_fifo self.url = url @@ -36,30 +44,33 @@ class NlsTtsSynthesizer: on_completed=self.on_completed, on_error=self.on_error, on_close=self.on_close, - callback_args=[] + callback_args=[], ) + def on_data(self, data, *args): self.count += len(data) + print(f"cout: {self.count}") self.tts_fifo.append(data) - #with open('tts_server.pcm', 'ab') as file: + # with open('tts_server.pcm', 'ab') as file: # file.write(data) + def on_sentence_begin(self, message, *args): - print('on sentence begin =>{}'.format(message)) + print("on sentence begin =>{}".format(message)) def on_sentence_synthesis(self, message, *args): - print('on sentence synthesis =>{}'.format(message)) + print("on sentence synthesis =>{}".format(message)) def on_sentence_end(self, message, *args): - print('on sentence end =>{}'.format(message)) + print("on sentence end =>{}".format(message)) def on_completed(self, message, *args): - print('on completed =>{}'.format(message)) + print("on completed =>{}".format(message)) def on_error(self, message, *args): - print('on_error args=>{}'.format(args)) + print("on_error args=>{}".format(args)) def on_close(self, *args): - 
print('on_close: args=>{}'.format(args)) + print("on_close: args=>{}".format(args)) print("on message data cout: =>{}".format(self.count)) self.started = False @@ -68,16 +79,18 @@ class NlsTtsSynthesizer: self.started = True def send_text(self, text): + print(f"text: {text}") self.sdk.sendStreamInputTts(text) async def stop(self): self.sdk.stopStreamInputTts() + parser = argparse.ArgumentParser() parser.add_argument( - "--host", type=str, default="127.0.0.1", required=False, help="host ip, localhost, 0.0.0.0" + "--host", type=str, default="0.0.0.0", required=False, help="host ip, localhost, 0.0.0.0" ) -parser.add_argument("--port", type=int, default=10095, required=False, help="grpc server port") +parser.add_argument("--port", type=int, default=10096, required=False, help="grpc server port") parser.add_argument( "--asr_model", type=str, @@ -124,16 +137,6 @@ websocket_users = set() print("model loading") from funasr import AutoModel -# # asr -# model_asr = AutoModel( -# model=args.asr_model, -# model_revision=args.asr_model_revision, -# ngpu=args.ngpu, -# ncpu=args.ncpu, -# device=args.device, -# disable_pbar=False, -# disable_log=True, -# ) # vad model_vad = AutoModel( @@ -147,26 +150,6 @@ model_vad = AutoModel( # chunk_size=60, ) -# async def async_asr(websocket, audio_in): -# if len(audio_in) > 0: -# # print(len(audio_in)) -# print(type(audio_in)) -# rec_result = model_asr.generate(input=audio_in, **websocket.status_dict_asr)[0] -# print("offline_asr, ", rec_result) -# -# -# if len(rec_result["text"]) > 0: -# # print("offline", rec_result) -# mode = "2pass-offline" if "2pass" in websocket.mode else websocket.mode -# message = json.dumps( -# { -# "mode": mode, -# "text": rec_result["text"], -# "wav_name": websocket.wav_name, -# "is_final": websocket.is_speaking, -# } -# ) -# await websocket.send(message) import os @@ -205,20 +188,26 @@ if "key" in os.environ: key = os.environ["key"] api.login(key) +appkey = "xxx" +appkey_token = "xxx" +if "appkey" in os.environ: + 
appkey = os.environ["appkey"] + appkey_token = os.environ["appkey_token"] + from modelscope.hub.snapshot_download import snapshot_download -# os.environ["MODELSCOPE_CACHE"] = "/nfs/zhifu.gzf/modelscope" -# llm_dir = snapshot_download('qwen/Qwen2-7B-Instruct', cache_dir=None, revision='master') -# audio_encoder_dir = snapshot_download('iic/SenseVoice', cache_dir=None, revision='master') +os.environ["MODELSCOPE_CACHE"] = "/mnt/workspace" +llm_dir = snapshot_download("qwen/Qwen2-7B-Instruct", cache_dir=None, revision="master") +audio_encoder_dir = snapshot_download("iic/SenseVoice", cache_dir=None, revision="master") -llm_dir = "/cpfs_speech/zhifu.gzf/init_model/qwen/Qwen2-7B-Instruct" -audio_encoder_dir = "/nfs/zhifu.gzf/init_model/SenseVoiceLargeModelscope" +# llm_dir = "/cpfs_speech/zhifu.gzf/init_model/qwen/Qwen2-7B-Instruct" +# audio_encoder_dir = "/nfs/zhifu.gzf/init_model/SenseVoiceLargeModelscope" device = "cuda:0" all_file_paths = [ - "/nfs/zhifu.gzf/init_model/Speech2Text_Align_V0712_modelscope" - # "FunAudioLLM/Speech2Text_Align_V0712", + # "/nfs/zhifu.gzf/init_model/Speech2Text_Align_V0712_modelscope" + "FunAudioLLM/Speech2Text_Align_V0712", # "FunAudioLLM/Speech2Text_Align_V0718", # "FunAudioLLM/Speech2Text_Align_V0628", ] @@ -246,6 +235,7 @@ tokenizer = model_llm.kwargs["tokenizer"] model_dict = {"model": model, "frontend": frontend, "tokenizer": tokenizer} + async def send_to_client(websocket, syntheszier, tts_fifo): # Sending tts data to the client while True: @@ -260,6 +250,8 @@ async def send_to_client(websocket, syntheszier, tts_fifo): else: print("WebSocket connection is not open or syntheszier is not started.") break + + async def model_inference( websocket, audio_in, @@ -271,7 +263,9 @@ async def model_inference( text_usr="", ): fifo_queue = deque() - synthesizer = NlsTtsSynthesizer(websocket=websocket, tts_fifo=fifo_queue, token="xxx", appkey="xxx") + synthesizer = NlsTtsSynthesizer( + websocket=websocket, tts_fifo=fifo_queue, 
token=appkey_token, appkey=appkey + ) synthesizer.start() beg0 = time.time() if his_state is None: @@ -329,8 +323,10 @@ async def model_inference( f"generated new text: {new_text}, time_fr_receive: {end_llm - beg0:.2f}, time_llm_decode: {end_llm - beg_llm:.2f}" ) if len(new_text) > 0: + new_text = new_text.replace("<|im_end|>", "") + res += new_text synthesizer.send_text(new_text) - res += new_text.replace("<|im_end|>", "") + contents_i[-1]["content"] = res websocket.llm_state["contents_i"] = contents_i # history[-1][1] = res @@ -341,7 +337,7 @@ async def model_inference( "mode": mode, "text": new_text, "wav_name": websocket.wav_name, - "is_final": websocket.is_speaking, + "is_final": False, } ) # print(f"online: {message}") @@ -350,6 +346,7 @@ async def model_inference( while len(fifo_queue) > 0: await websocket.send(fifo_queue.popleft()) + # synthesizer.send_text(res) tts_to_client_task = asyncio.create_task(send_to_client(websocket, synthesizer, fifo_queue)) synthesizer.stop() await tts_to_client_task @@ -359,7 +356,7 @@ async def model_inference( "mode": mode, "text": res, "wav_name": websocket.wav_name, - "is_final": websocket.is_speaking, + "is_final": True, } ) # print(f"offline: {message}") @@ -460,7 +457,8 @@ async def ws_serve(websocket, path): frames_asr = [] frames_asr.extend(frames_pre) # asr punc offline - if speech_end_i != -1 or not websocket.is_speaking: + # if speech_end_i != -1 or not websocket.is_speaking: + if not websocket.is_speaking: # print("vad end point") if websocket.mode == "2pass" or websocket.mode == "offline": audio_in = b"".join(frames_asr) @@ -506,7 +504,7 @@ async def async_vad(websocket, audio_in): return speech_start, speech_end -if len(args.certfile) > 0: +if False: # len(args.certfile) > 0: ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) # Generate with Lets Encrypt, copied to this location, chown to current user and 400 permissions diff --git a/runtime/python/websocket/funasr_wss_server_streaming_llm.py 
b/runtime/python/websocket/funasr_wss_server_streaming_llm.py index 535a2828b..e61a8c3e3 100644 --- a/runtime/python/websocket/funasr_wss_server_streaming_llm.py +++ b/runtime/python/websocket/funasr_wss_server_streaming_llm.py @@ -3,6 +3,7 @@ import asyncio import json import websockets import time +from datetime import datetime import argparse import ssl import numpy as np @@ -14,7 +15,7 @@ from modelscope.hub.snapshot_download import snapshot_download parser = argparse.ArgumentParser() parser.add_argument( - "--host", type=str, default="127.0.0.1", required=False, help="host ip, localhost, 0.0.0.0" + "--host", type=str, default="0.0.0.0", required=False, help="host ip, localhost, 0.0.0.0" ) parser.add_argument("--port", type=int, default=10095, required=False, help="grpc server port") parser.add_argument( @@ -30,14 +31,14 @@ parser.add_argument("--ncpu", type=int, default=4, help="cpu cores") parser.add_argument( "--certfile", type=str, - default="../../ssl_key/server.crt", + default="", required=False, help="certfile for ssl", ) parser.add_argument( "--keyfile", type=str, - default="../../ssl_key/server.key", + default="ssl_key/server.key", required=False, help="keyfile for ssl", ) @@ -61,20 +62,25 @@ model_vad = AutoModel( ) api = HubApi() +key = "xxx" +api.login(key) if "key" in os.environ: key = os.environ["key"] - api.login(key) +api.login(key) # os.environ["MODELSCOPE_CACHE"] = "/nfs/zhifu.gzf/modelscope" -# llm_dir = snapshot_download('qwen/Qwen2-7B-Instruct', cache_dir=None, revision='master') -# audio_encoder_dir = snapshot_download('iic/SenseVoice', cache_dir=None, revision='master') +llm_dir = snapshot_download("qwen/Qwen2-7B-Instruct", cache_dir=None, revision="master") +audio_encoder_dir = snapshot_download("iic/SenseVoice", cache_dir=None, revision="master") -llm_dir = "/cpfs_speech/zhifu.gzf/init_model/qwen/Qwen2-7B-Instruct" -audio_encoder_dir = "/nfs/yangyexin.yyx/init_model/iic/SenseVoiceModelscope_0712" +#
llm_dir = "/cpfs_speech/zhifu.gzf/init_model/qwen/Qwen2-7B-Instruct" +# audio_encoder_dir = "/nfs/yangyexin.yyx/init_model/iic/SenseVoiceModelscope_0712" device = "cuda:0" all_file_paths = [ - "/nfs/yangyexin.yyx/init_model/qwen2_7b_mmt_v14_20240830/", - "/nfs/yangyexin.yyx/init_model/audiolm_v14_20240824_train_encoder_all_20240822_lr1e-4_warmup2350/" + "FunAudioLLM/qwen2_7b_mmt_v14_20240830", + "FunAudioLLM/audiolm_v11_20240807", + "FunAudioLLM/Speech2Text_Align_V0712", + "FunAudioLLM/Speech2Text_Align_V0718", + "FunAudioLLM/Speech2Text_Align_V0628", ] llm_kwargs = {"num_beams": 1, "do_sample": False} @@ -85,6 +91,7 @@ MAX_ITER_PER_CHUNK = 20 ckpt_dir = all_file_paths[0] + def contains_lora_folder(directory): for name in os.listdir(directory): full_path = os.path.join(directory, name) @@ -92,7 +99,10 @@ def contains_lora_folder(directory): return full_path return None + +ckpt_dir = snapshot_download(ckpt_dir, cache_dir=None, revision="master") lora_folder = contains_lora_folder(ckpt_dir) + if lora_folder is not None: model_llm = AutoModel( model=ckpt_dir, @@ -102,7 +112,11 @@ if lora_folder is not None: llm_dtype="bf16", max_length=1024, llm_kwargs=llm_kwargs, - llm_conf={"init_param_path": llm_dir, "lora_conf": {"init_param_path": lora_folder}, "load_kwargs": {"attn_implementation": "eager"}}, + llm_conf={ + "init_param_path": llm_dir, + "lora_conf": {"init_param_path": lora_folder}, + "load_kwargs": {"attn_implementation": "eager"}, + }, tokenizer_conf={"init_param_path": llm_dir}, audio_encoder=audio_encoder_dir, ) @@ -128,6 +142,7 @@ model_dict = {"model": model, "frontend": frontend, "tokenizer": tokenizer} print("model loaded! 
only support one client at the same time now!!!!") + def load_bytes(input): middle_data = np.frombuffer(input, dtype=np.int16) middle_data = np.asarray(middle_data) @@ -143,7 +158,12 @@ def load_bytes(input): array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32) return array -async def streaming_transcribe(websocket, audio_in, his_state=None, asr_prompt=None, s2tt_prompt=None): + +async def streaming_transcribe( + websocket, audio_in, his_state=None, asr_prompt=None, s2tt_prompt=None +): + current_time = datetime.now() + print("DEBUG:" + str(current_time) + " call streaming_transcribe function:") if his_state is None: his_state = model_dict model = his_state["model"] @@ -157,17 +177,21 @@ async def streaming_transcribe(websocket, audio_in, his_state=None, asr_prompt=N else: previous_asr_text = websocket.streaming_state.get("previous_asr_text", "") previous_s2tt_text = websocket.streaming_state.get("previous_s2tt_text", "") - previous_vad_onscreen_asr_text = websocket.streaming_state.get("previous_vad_onscreen_asr_text", "") - previous_vad_onscreen_s2tt_text = websocket.streaming_state.get("previous_vad_onscreen_s2tt_text", "") - + previous_vad_onscreen_asr_text = websocket.streaming_state.get( + "previous_vad_onscreen_asr_text", "" + ) + previous_vad_onscreen_s2tt_text = websocket.streaming_state.get( + "previous_vad_onscreen_s2tt_text", "" + ) + if asr_prompt is None or asr_prompt == "": - asr_prompt = "Copy:" + asr_prompt = "Speech transcription:" if s2tt_prompt is None or s2tt_prompt == "": - s2tt_prompt = "Translate the following sentence into English:" + s2tt_prompt = "Translate into English:" audio_seconds = load_bytes(audio_in).shape[0] / 16000 print(f"Streaming audio length: {audio_seconds} seconds") - + asr_content = [] system_prompt = "You are a helpful assistant." 
asr_content.append({"role": "system", "content": system_prompt}) @@ -185,12 +209,24 @@ async def streaming_transcribe(websocket, audio_in, his_state=None, asr_prompt=N streaming_time_beg = time.time() inputs_asr_embeds, contents, batch, source_ids, meta_data = model.inference_prepare( - [asr_content], None, "test_demo", tokenizer, frontend, device=device, infer_with_assistant_input=True + [asr_content], + None, + "test_demo", + tokenizer, + frontend, + device=device, + infer_with_assistant_input=True, ) model_asr_inputs = {} model_asr_inputs["inputs_embeds"] = inputs_asr_embeds inputs_s2tt_embeds, contents, batch, source_ids, meta_data = model.inference_prepare( - [s2tt_content], None, "test_demo", tokenizer, frontend, device=device, infer_with_assistant_input=True + [s2tt_content], + None, + "test_demo", + tokenizer, + frontend, + device=device, + infer_with_assistant_input=True, ) model_s2tt_inputs = {} model_s2tt_inputs["inputs_embeds"] = inputs_s2tt_embeds @@ -208,7 +244,7 @@ async def streaming_transcribe(websocket, audio_in, his_state=None, asr_prompt=N s2tt_generation_kwargs.update(llm_kwargs) s2tt_thread = Thread(target=model.llm.generate, kwargs=s2tt_generation_kwargs) s2tt_thread.start() - + onscreen_asr_res = previous_asr_text onscreen_s2tt_res = previous_s2tt_text @@ -220,8 +256,8 @@ async def streaming_transcribe(websocket, audio_in, his_state=None, asr_prompt=N is_s2tt_repetition = False for new_asr_text in asr_streamer: - print(f"generated new asr text: {new_asr_text}") - + current_time = datetime.now() + print("DEBUG: " + str(current_time) + " " + f"generated new asr text: {new_asr_text}") if len(new_asr_text) > 0: onscreen_asr_res += new_asr_text.replace("<|im_end|>", "") if len(new_asr_text.replace("<|im_end|>", "")) > 0: @@ -233,7 +269,13 @@ async def streaming_transcribe(websocket, audio_in, his_state=None, asr_prompt=N if remain_s2tt_text: try: new_s2tt_text = next(s2tt_streamer) - print(f"generated new s2tt text: {new_s2tt_text}") + 
current_time = datetime.now() + print( + "DEBUG: " + + str(current_time) + + " " + + f"generated new s2tt text: {new_s2tt_text}" + ) s2tt_iter_cnt += 1 if len(new_s2tt_text) > 0: onscreen_s2tt_res += new_s2tt_text.replace("<|im_end|>", "") @@ -241,16 +283,16 @@ async def streaming_transcribe(websocket, audio_in, his_state=None, asr_prompt=N new_s2tt_text = "" remain_s2tt_text = False pass - + if len(new_asr_text) > 0 or len(new_s2tt_text) > 0: all_asr_res = previous_vad_onscreen_asr_text + onscreen_asr_res fix_asr_part = previous_vad_onscreen_asr_text + previous_asr_text - unfix_asr_part = all_asr_res[len(fix_asr_part):] - return_asr_res = fix_asr_part + ""+ unfix_asr_part + "" + unfix_asr_part = all_asr_res[len(fix_asr_part) :] + return_asr_res = fix_asr_part + "" + unfix_asr_part + "" all_s2tt_res = previous_vad_onscreen_s2tt_text + onscreen_s2tt_res fix_s2tt_part = previous_vad_onscreen_s2tt_text + previous_s2tt_text - unfix_s2tt_part = all_s2tt_res[len(fix_s2tt_part):] - return_s2tt_res = fix_s2tt_part + ""+ unfix_s2tt_part + "" + unfix_s2tt_part = all_s2tt_res[len(fix_s2tt_part) :] + return_s2tt_res = fix_s2tt_part + "" + unfix_s2tt_part + "" message = json.dumps( { "mode": "online", @@ -261,14 +303,19 @@ async def streaming_transcribe(websocket, audio_in, his_state=None, asr_prompt=N } ) await websocket.send(message) - websocket.streaming_state["onscreen_asr_res"] = previous_vad_onscreen_asr_text + onscreen_asr_res - websocket.streaming_state["onscreen_s2tt_res"] = previous_vad_onscreen_s2tt_text + onscreen_s2tt_res + websocket.streaming_state["onscreen_asr_res"] = ( + previous_vad_onscreen_asr_text + onscreen_asr_res + ) + websocket.streaming_state["onscreen_s2tt_res"] = ( + previous_vad_onscreen_s2tt_text + onscreen_s2tt_res + ) - if remain_s2tt_text: for new_s2tt_text in s2tt_streamer: - print(f"generated new s2tt text: {new_s2tt_text}") - + current_time = datetime.now() + print( + "DEBUG: " + str(current_time) + " " + f"generated new s2tt text: 
{new_s2tt_text}" + ) if len(new_s2tt_text) > 0: onscreen_s2tt_res += new_s2tt_text.replace("<|im_end|>", "") if len(new_s2tt_text.replace("<|im_end|>", "")) > 0: @@ -276,16 +323,16 @@ async def streaming_transcribe(websocket, audio_in, his_state=None, asr_prompt=N if s2tt_iter_cnt > MAX_ITER_PER_CHUNK: is_s2tt_repetition = True break - + if len(new_s2tt_text) > 0: all_asr_res = previous_vad_onscreen_asr_text + onscreen_asr_res fix_asr_part = previous_vad_onscreen_asr_text + previous_asr_text - unfix_asr_part = all_asr_res[len(fix_asr_part):] - return_asr_res = fix_asr_part + ""+ unfix_asr_part + "" + unfix_asr_part = all_asr_res[len(fix_asr_part) :] + return_asr_res = fix_asr_part + "" + unfix_asr_part + "" all_s2tt_res = previous_vad_onscreen_s2tt_text + onscreen_s2tt_res fix_s2tt_part = previous_vad_onscreen_s2tt_text + previous_s2tt_text - unfix_s2tt_part = all_s2tt_res[len(fix_s2tt_part):] - return_s2tt_res = fix_s2tt_part + ""+ unfix_s2tt_part + "" + unfix_s2tt_part = all_s2tt_res[len(fix_s2tt_part) :] + return_s2tt_res = fix_s2tt_part + "" + unfix_s2tt_part + "" message = json.dumps( { "mode": "online", @@ -296,8 +343,12 @@ async def streaming_transcribe(websocket, audio_in, his_state=None, asr_prompt=N } ) await websocket.send(message) - websocket.streaming_state["onscreen_asr_res"] = previous_vad_onscreen_asr_text + onscreen_asr_res - websocket.streaming_state["onscreen_s2tt_res"] = previous_vad_onscreen_s2tt_text + onscreen_s2tt_res + websocket.streaming_state["onscreen_asr_res"] = ( + previous_vad_onscreen_asr_text + onscreen_asr_res + ) + websocket.streaming_state["onscreen_s2tt_res"] = ( + previous_vad_onscreen_s2tt_text + onscreen_s2tt_res + ) streaming_time_end = time.time() print(f"Streaming inference time: {streaming_time_end - streaming_time_beg}") @@ -307,16 +358,24 @@ async def streaming_transcribe(websocket, audio_in, his_state=None, asr_prompt=N if asr_text_len > UNFIX_LEN and audio_seconds > MIN_LEN_SEC_AUDIO_FIX and not is_asr_repetition: 
pre_previous_asr_text = previous_asr_text - previous_asr_text = tokenizer.decode(tokenizer.encode(onscreen_asr_res)[:-UNFIX_LEN]).replace("�", "") + previous_asr_text = tokenizer.decode( + tokenizer.encode(onscreen_asr_res)[:-UNFIX_LEN] + ).replace("�", "") if len(previous_asr_text) <= len(pre_previous_asr_text): previous_asr_text = pre_previous_asr_text elif is_asr_repetition: pass else: previous_asr_text = "" - if s2tt_text_len > UNFIX_LEN and audio_seconds > MIN_LEN_SEC_AUDIO_FIX and not is_s2tt_repetition: + if ( + s2tt_text_len > UNFIX_LEN + and audio_seconds > MIN_LEN_SEC_AUDIO_FIX + and not is_s2tt_repetition + ): pre_previous_s2tt_text = previous_s2tt_text - previous_s2tt_text = tokenizer.decode(tokenizer.encode(onscreen_s2tt_res)[:-UNFIX_LEN]).replace("�", "") + previous_s2tt_text = tokenizer.decode( + tokenizer.encode(onscreen_s2tt_res)[:-UNFIX_LEN] + ).replace("�", "") if len(previous_s2tt_text) <= len(pre_previous_s2tt_text): previous_s2tt_text = pre_previous_s2tt_text elif is_s2tt_repetition: @@ -325,9 +384,13 @@ async def streaming_transcribe(websocket, audio_in, his_state=None, asr_prompt=N previous_s2tt_text = "" websocket.streaming_state["previous_asr_text"] = previous_asr_text - websocket.streaming_state["onscreen_asr_res"] = previous_vad_onscreen_asr_text + onscreen_asr_res + websocket.streaming_state["onscreen_asr_res"] = ( + previous_vad_onscreen_asr_text + onscreen_asr_res + ) websocket.streaming_state["previous_s2tt_text"] = previous_s2tt_text - websocket.streaming_state["onscreen_s2tt_res"] = previous_vad_onscreen_s2tt_text + onscreen_s2tt_res + websocket.streaming_state["onscreen_s2tt_res"] = ( + previous_vad_onscreen_s2tt_text + onscreen_s2tt_res + ) print("fix asr part:", previous_asr_text) print("fix s2tt part:", previous_s2tt_text) @@ -383,6 +446,13 @@ async def ws_serve(websocket, path): try: async for message in websocket: + if isinstance(message, str): + current_time = datetime.now() + print("DEBUG:" + str(current_time) + " received 
message:", message) + else: + current_time = datetime.now() + print("DEBUG:" + str(current_time) + " received audio bytes:") + if isinstance(message, str): messagejson = json.loads(message) @@ -403,11 +473,11 @@ async def ws_serve(websocket, path): if "asr_prompt" in messagejson: asr_prompt = messagejson["asr_prompt"] else: - asr_prompt = "Copy:" + asr_prompt = "Speech transcription:" if "s2tt_prompt" in messagejson: s2tt_prompt = messagejson["s2tt_prompt"] else: - s2tt_prompt = "Translate the following sentence into English:" + s2tt_prompt = "Translate into English:" websocket.status_dict_vad["chunk_size"] = int( chunk_size[1] * 60 / websocket.chunk_interval @@ -421,19 +491,20 @@ async def ws_serve(websocket, path): # asr online websocket.streaming_state["is_final"] = speech_end_i != -1 if ( - (len(frames_asr) % websocket.chunk_interval == 0 - or websocket.streaming_state["is_final"]) - and len(frames_asr) != 0 - ): + len(frames_asr) % websocket.chunk_interval == 0 + or websocket.streaming_state["is_final"] + ) and len(frames_asr) != 0: audio_in = b"".join(frames_asr) try: - await streaming_transcribe(websocket, audio_in, asr_prompt=asr_prompt, s2tt_prompt=s2tt_prompt) + await streaming_transcribe( + websocket, audio_in, asr_prompt=asr_prompt, s2tt_prompt=s2tt_prompt + ) except Exception as e: print(f"error in streaming, {e}") print(f"error in streaming, {websocket.streaming_state}") if speech_start: frames_asr.append(message) - + # vad online try: speech_start_i, speech_end_i = await async_vad(websocket, message) @@ -450,7 +521,9 @@ async def ws_serve(websocket, path): if speech_end_i != -1 or not websocket.is_speaking: audio_in = b"".join(frames_asr) try: - await streaming_transcribe(websocket, audio_in, asr_prompt=asr_prompt, s2tt_prompt=s2tt_prompt) + await streaming_transcribe( + websocket, audio_in, asr_prompt=asr_prompt, s2tt_prompt=s2tt_prompt + ) except Exception as e: print(f"error in streaming, {e}") print(f"error in streaming, 
{websocket.streaming_state}") @@ -460,16 +533,37 @@ async def ws_serve(websocket, path): websocket.streaming_state["previous_s2tt_text"] = "" now_onscreen_asr_res = websocket.streaming_state.get("onscreen_asr_res", "") now_onscreen_s2tt_res = websocket.streaming_state.get("onscreen_s2tt_res", "") - if len(tokenizer.encode(now_onscreen_asr_res.split("\n\n")[-1])) < MIN_LEN_PER_PARAGRAPH or len(tokenizer.encode(now_onscreen_s2tt_res.split("\n\n")[-1])) < MIN_LEN_PER_PARAGRAPH: - if now_onscreen_asr_res.endswith(".") or now_onscreen_asr_res.endswith("?") or now_onscreen_asr_res.endswith("!"): + if ( + len(tokenizer.encode(now_onscreen_asr_res.split("\n\n")[-1])) + < MIN_LEN_PER_PARAGRAPH + or len(tokenizer.encode(now_onscreen_s2tt_res.split("\n\n")[-1])) + < MIN_LEN_PER_PARAGRAPH + ): + if ( + now_onscreen_asr_res.endswith(".") + or now_onscreen_asr_res.endswith("?") + or now_onscreen_asr_res.endswith("!") + ): now_onscreen_asr_res += " " - if now_onscreen_s2tt_res.endswith(".") or now_onscreen_s2tt_res.endswith("?") or now_onscreen_s2tt_res.endswith("!"): + if ( + now_onscreen_s2tt_res.endswith(".") + or now_onscreen_s2tt_res.endswith("?") + or now_onscreen_s2tt_res.endswith("!") + ): now_onscreen_s2tt_res += " " - websocket.streaming_state["previous_vad_onscreen_asr_text"] = now_onscreen_asr_res - websocket.streaming_state["previous_vad_onscreen_s2tt_text"] = now_onscreen_s2tt_res + websocket.streaming_state["previous_vad_onscreen_asr_text"] = ( + now_onscreen_asr_res + ) + websocket.streaming_state["previous_vad_onscreen_s2tt_text"] = ( + now_onscreen_s2tt_res + ) else: - websocket.streaming_state["previous_vad_onscreen_asr_text"] = now_onscreen_asr_res + "\n\n" - websocket.streaming_state["previous_vad_onscreen_s2tt_text"] = now_onscreen_s2tt_res + "\n\n" + websocket.streaming_state["previous_vad_onscreen_asr_text"] = ( + now_onscreen_asr_res + "\n\n" + ) + websocket.streaming_state["previous_vad_onscreen_s2tt_text"] = ( + now_onscreen_s2tt_res + "\n\n" + ) if not 
websocket.is_speaking: websocket.vad_pre_idx = 0 frames = [] @@ -491,6 +585,8 @@ async def ws_serve(websocket, path): async def async_vad(websocket, audio_in): + current_time = datetime.now() + print("DEBUG:" + str(current_time) + " call vad function:") segments_result = model_vad.generate(input=audio_in, **websocket.status_dict_vad)[0]["value"] # print(segments_result) @@ -506,7 +602,7 @@ async def async_vad(websocket, audio_in): return speech_start, speech_end -if len(args.certfile) > 0: +if False: ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) # Generate with Lets Encrypt, copied to this location, chown to current user and 400 permissions