feat: Resolve conflict, auto committed by CodeFlow

This commit is contained in:
zhifu.gzf 2024-09-03 13:54:42 +08:00
commit 963472437c
2 changed files with 567 additions and 3 deletions

View File

@ -49,6 +49,17 @@ from queue import Queue
voices = Queue()
class Colors:
HEADER = '\033[95m'
OKBLUE = '\033[94m'
OKCYAN = '\033[96m'
OKGREEN = '\033[92m'
WARNING = '\033[93m'
FAIL = '\033[91m'
ENDC = '\033[0m' # 重置颜色
BOLD = '\033[1m'
UNDERLINE = '\033[4m'
async def record_microphone():
is_finished = False
@ -185,12 +196,24 @@ async def message(id):
asr_text = meg["asr_text"]
s2tt_text = meg["s2tt_text"]
if prev_asr_text.startswith(asr_text) and prev_s2tt_text.startswith(s2tt_text):
continue
clean_prev_asr_text = prev_asr_text.replace("<em>", "").replace("</em>", "")
clean_prev_s2tt_text = prev_s2tt_text.replace("<em>", "").replace("</em>", "")
clean_asr_text = asr_text.replace("<em>", "").replace("</em>", "")
clean_s2tt_text = s2tt_text.replace("<em>", "").replace("</em>", "")
if clean_prev_asr_text.startswith(clean_asr_text):
new_asr_unfix_pos = asr_text.find("<em>")
asr_text = clean_prev_asr_text[:new_asr_unfix_pos] + "<em>" + clean_prev_asr_text[new_asr_unfix_pos:] + "</em>"
if clean_prev_s2tt_text.startswith(clean_s2tt_text):
new_s2tt_unfix_pos = s2tt_text.find("<em>")
s2tt_text = clean_prev_s2tt_text[:new_s2tt_unfix_pos] + "<em>" + clean_prev_s2tt_text[new_s2tt_unfix_pos:] + "</em>"
prev_asr_text = asr_text
prev_s2tt_text = s2tt_text
text_print = "\n\n" + "ASR: " + asr_text + "\n\n" + "S2TT: " + s2tt_text
print_asr_text = Colors.OKGREEN + asr_text[:asr_text.find("<em>")] + Colors.ENDC + Colors.OKCYAN + asr_text[asr_text.find("<em>") + len("<em>"): -len("</em>")] + Colors.ENDC
print_s2tt_text = Colors.OKGREEN + s2tt_text[:s2tt_text.find("<em>")] + Colors.ENDC + Colors.OKCYAN + s2tt_text[s2tt_text.find("<em>") + len("<em>"): -len("</em>")] + Colors.ENDC
text_print = "\n\n" + "ASR: " + print_asr_text + "\n\n" + "S2TT: " + print_s2tt_text
os.system("clear")
print("\rpid" + str(id) + ": " + text_print)

View File

@ -0,0 +1,541 @@
import os
import asyncio
import json
import websockets
import time
from datetime import datetime
import argparse
import ssl
import numpy as np
from threading import Thread
from transformers import TextIteratorStreamer
from funasr import AutoModel
from modelscope.hub.api import HubApi
from modelscope.hub.snapshot_download import snapshot_download
parser = argparse.ArgumentParser()
parser.add_argument(
"--host", type=str, default="127.0.0.1", required=False, help="host ip, localhost, 0.0.0.0"
)
parser.add_argument("--port", type=int, default=10095, required=False, help="grpc server port")
parser.add_argument(
"--vad_model",
type=str,
default="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
help="model from modelscope",
)
parser.add_argument("--vad_model_revision", type=str, default="master", help="")
parser.add_argument("--ngpu", type=int, default=1, help="0 for cpu, 1 for gpu")
parser.add_argument("--device", type=str, default="cuda", help="cuda, cpu")
parser.add_argument("--ncpu", type=int, default=4, help="cpu cores")
parser.add_argument(
"--certfile",
type=str,
default="",
required=False,
help="certfile for ssl",
)
parser.add_argument(
"--keyfile",
type=str,
default="ssl_key/server.key",
required=False,
help="keyfile for ssl",
)
args = parser.parse_args()
websocket_users = set()
print("model loading")
# vad
model_vad = AutoModel(
model=args.vad_model,
model_revision=args.vad_model_revision,
ngpu=args.ngpu,
ncpu=args.ncpu,
device=args.device,
disable_pbar=True,
disable_log=True,
max_single_segment_time=40000,
max_end_silence_time=580,
# chunk_size=60,
)
api = HubApi()
if "key" in os.environ:
key = os.environ["key"]
api.login(key)
# os.environ["MODELSCOPE_CACHE"] = "/nfs/zhifu.gzf/modelscope"
llm_dir = snapshot_download('qwen/Qwen2-7B-Instruct', cache_dir=None, revision='master')
audio_encoder_dir = snapshot_download('iic/SenseVoice', cache_dir=None, revision='master')
# llm_dir = "/cpfs_speech/zhifu.gzf/init_model/qwen/Qwen2-7B-Instruct"
# audio_encoder_dir = "/nfs/yangyexin.yyx/init_model/iic/SenseVoiceModelscope_0712"
device = "cuda:0"
all_file_paths = [
"FunAudioLLM/qwen2_7b_mmt_v14_20240830",
"FunAudioLLM/audiolm_v11_20240807",
"FunAudioLLM/Speech2Text_Align_V0712",
"FunAudioLLM/Speech2Text_Align_V0718",
"FunAudioLLM/Speech2Text_Align_V0628",
]
llm_kwargs = {"num_beams": 1, "do_sample": False}
UNFIX_LEN = 5
MIN_LEN_PER_PARAGRAPH = 25
MIN_LEN_SEC_AUDIO_FIX = 1.1
MAX_ITER_PER_CHUNK = 20
ckpt_dir = all_file_paths[0]
def contains_lora_folder(directory):
for name in os.listdir(directory):
full_path = os.path.join(directory, name)
if os.path.isdir(full_path) and "lora" in name:
return full_path
return None
lora_folder = contains_lora_folder(ckpt_dir)
if lora_folder is not None:
model_llm = AutoModel(
model=ckpt_dir,
device=device,
fp16=False,
bf16=False,
llm_dtype="bf16",
max_length=1024,
llm_kwargs=llm_kwargs,
llm_conf={"init_param_path": llm_dir, "lora_conf": {"init_param_path": lora_folder}, "load_kwargs": {"attn_implementation": "eager"}},
tokenizer_conf={"init_param_path": llm_dir},
audio_encoder=audio_encoder_dir,
)
else:
model_llm = AutoModel(
model=ckpt_dir,
device=device,
fp16=False,
bf16=False,
llm_dtype="bf16",
max_length=1024,
llm_kwargs=llm_kwargs,
llm_conf={"init_param_path": llm_dir, "load_kwargs": {"attn_implementation": "eager"}},
tokenizer_conf={"init_param_path": llm_dir},
audio_encoder=audio_encoder_dir,
)
model = model_llm.model
frontend = model_llm.kwargs["frontend"]
tokenizer = model_llm.kwargs["tokenizer"]
model_dict = {"model": model, "frontend": frontend, "tokenizer": tokenizer}
print("model loaded! only support one client at the same time now!!!!")
def load_bytes(input):
middle_data = np.frombuffer(input, dtype=np.int16)
middle_data = np.asarray(middle_data)
if middle_data.dtype.kind not in "iu":
raise TypeError("'middle_data' must be an array of integers")
dtype = np.dtype("float32")
if dtype.kind != "f":
raise TypeError("'dtype' must be a floating point type")
i = np.iinfo(middle_data.dtype)
abs_max = 2 ** (i.bits - 1)
offset = i.min + abs_max
array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
return array
async def streaming_transcribe(websocket, audio_in, his_state=None, asr_prompt=None, s2tt_prompt=None):
current_time = datetime.now()
print("DEBUG:" + str(current_time) + " call streaming_transcribe function:")
if his_state is None:
his_state = model_dict
model = his_state["model"]
tokenizer = his_state["tokenizer"]
if websocket.streaming_state is None:
previous_asr_text = ""
previous_s2tt_text = ""
previous_vad_onscreen_asr_text = ""
previous_vad_onscreen_s2tt_text = ""
else:
previous_asr_text = websocket.streaming_state.get("previous_asr_text", "")
previous_s2tt_text = websocket.streaming_state.get("previous_s2tt_text", "")
previous_vad_onscreen_asr_text = websocket.streaming_state.get("previous_vad_onscreen_asr_text", "")
previous_vad_onscreen_s2tt_text = websocket.streaming_state.get("previous_vad_onscreen_s2tt_text", "")
if asr_prompt is None or asr_prompt == "":
asr_prompt = "Copy:"
if s2tt_prompt is None or s2tt_prompt == "":
s2tt_prompt = "Translate the following sentence into English:"
audio_seconds = load_bytes(audio_in).shape[0] / 16000
print(f"Streaming audio length: {audio_seconds} seconds")
asr_content = []
system_prompt = "You are a helpful assistant."
asr_content.append({"role": "system", "content": system_prompt})
s2tt_content = []
system_prompt = "You are a helpful assistant."
s2tt_content.append({"role": "system", "content": system_prompt})
user_asr_prompt = f"{asr_prompt}<|startofspeech|>!!<|endofspeech|><|im_end|>\n<|im_start|>assistant\n{previous_asr_text}"
user_s2tt_prompt = f"{s2tt_prompt}<|startofspeech|>!!<|endofspeech|><|im_end|>\n<|im_start|>assistant\n{previous_s2tt_text}"
asr_content.append({"role": "user", "content": user_asr_prompt, "audio": audio_in})
asr_content.append({"role": "assistant", "content": "target_out"})
s2tt_content.append({"role": "user", "content": user_s2tt_prompt, "audio": audio_in})
s2tt_content.append({"role": "assistant", "content": "target_out"})
streaming_time_beg = time.time()
inputs_asr_embeds, contents, batch, source_ids, meta_data = model.inference_prepare(
[asr_content], None, "test_demo", tokenizer, frontend, device=device, infer_with_assistant_input=True
)
model_asr_inputs = {}
model_asr_inputs["inputs_embeds"] = inputs_asr_embeds
inputs_s2tt_embeds, contents, batch, source_ids, meta_data = model.inference_prepare(
[s2tt_content], None, "test_demo", tokenizer, frontend, device=device, infer_with_assistant_input=True
)
model_s2tt_inputs = {}
model_s2tt_inputs["inputs_embeds"] = inputs_s2tt_embeds
print("previous_asr_text:", previous_asr_text)
print("previous_s2tt_text:", previous_s2tt_text)
asr_streamer = TextIteratorStreamer(tokenizer)
asr_generation_kwargs = dict(model_asr_inputs, streamer=asr_streamer, max_new_tokens=1024)
asr_generation_kwargs.update(llm_kwargs)
asr_thread = Thread(target=model.llm.generate, kwargs=asr_generation_kwargs)
asr_thread.start()
s2tt_streamer = TextIteratorStreamer(tokenizer)
s2tt_generation_kwargs = dict(model_s2tt_inputs, streamer=s2tt_streamer, max_new_tokens=1024)
s2tt_generation_kwargs.update(llm_kwargs)
s2tt_thread = Thread(target=model.llm.generate, kwargs=s2tt_generation_kwargs)
s2tt_thread.start()
onscreen_asr_res = previous_asr_text
onscreen_s2tt_res = previous_s2tt_text
remain_s2tt_text = True
asr_iter_cnt = 0
s2tt_iter_cnt = 0
is_asr_repetition = False
is_s2tt_repetition = False
for new_asr_text in asr_streamer:
current_time = datetime.now()
print("DEBUG: " + str(current_time) + " " + f"generated new asr text {new_asr_text}")
if len(new_asr_text) > 0:
onscreen_asr_res += new_asr_text.replace("<|im_end|>", "")
if len(new_asr_text.replace("<|im_end|>", "")) > 0:
asr_iter_cnt += 1
if asr_iter_cnt > MAX_ITER_PER_CHUNK:
is_asr_repetition = True
break
if remain_s2tt_text:
try:
new_s2tt_text = next(s2tt_streamer)
current_time = datetime.now()
print("DEBUG: " + str(current_time) + " " + f"generated new s2tt text {new_s2tt_text}")
s2tt_iter_cnt += 1
if len(new_s2tt_text) > 0:
onscreen_s2tt_res += new_s2tt_text.replace("<|im_end|>", "")
except StopIteration:
new_s2tt_text = ""
remain_s2tt_text = False
pass
if len(new_asr_text) > 0 or len(new_s2tt_text) > 0:
all_asr_res = previous_vad_onscreen_asr_text + onscreen_asr_res
fix_asr_part = previous_vad_onscreen_asr_text + previous_asr_text
unfix_asr_part = all_asr_res[len(fix_asr_part):]
return_asr_res = fix_asr_part + "<em>"+ unfix_asr_part + "</em>"
all_s2tt_res = previous_vad_onscreen_s2tt_text + onscreen_s2tt_res
fix_s2tt_part = previous_vad_onscreen_s2tt_text + previous_s2tt_text
unfix_s2tt_part = all_s2tt_res[len(fix_s2tt_part):]
return_s2tt_res = fix_s2tt_part + "<em>"+ unfix_s2tt_part + "</em>"
message = json.dumps(
{
"mode": "online",
"asr_text": return_asr_res,
"s2tt_text": return_s2tt_res,
"wav_name": websocket.wav_name,
"is_final": websocket.is_speaking,
}
)
await websocket.send(message)
websocket.streaming_state["onscreen_asr_res"] = previous_vad_onscreen_asr_text + onscreen_asr_res
websocket.streaming_state["onscreen_s2tt_res"] = previous_vad_onscreen_s2tt_text + onscreen_s2tt_res
if remain_s2tt_text:
for new_s2tt_text in s2tt_streamer:
current_time = datetime.now()
print("DEBUG: " + str(current_time) + " " + f"generated new s2tt text {new_s2tt_text}")
if len(new_s2tt_text) > 0:
onscreen_s2tt_res += new_s2tt_text.replace("<|im_end|>", "")
if len(new_s2tt_text.replace("<|im_end|>", "")) > 0:
s2tt_iter_cnt += 1
if s2tt_iter_cnt > MAX_ITER_PER_CHUNK:
is_s2tt_repetition = True
break
if len(new_s2tt_text) > 0:
all_asr_res = previous_vad_onscreen_asr_text + onscreen_asr_res
fix_asr_part = previous_vad_onscreen_asr_text + previous_asr_text
unfix_asr_part = all_asr_res[len(fix_asr_part):]
return_asr_res = fix_asr_part + "<em>"+ unfix_asr_part + "</em>"
all_s2tt_res = previous_vad_onscreen_s2tt_text + onscreen_s2tt_res
fix_s2tt_part = previous_vad_onscreen_s2tt_text + previous_s2tt_text
unfix_s2tt_part = all_s2tt_res[len(fix_s2tt_part):]
return_s2tt_res = fix_s2tt_part + "<em>"+ unfix_s2tt_part + "</em>"
message = json.dumps(
{
"mode": "online",
"asr_text": return_asr_res,
"s2tt_text": return_s2tt_res,
"wav_name": websocket.wav_name,
"is_final": websocket.is_speaking,
}
)
await websocket.send(message)
websocket.streaming_state["onscreen_asr_res"] = previous_vad_onscreen_asr_text + onscreen_asr_res
websocket.streaming_state["onscreen_s2tt_res"] = previous_vad_onscreen_s2tt_text + onscreen_s2tt_res
streaming_time_end = time.time()
print(f"Streaming inference time: {streaming_time_end - streaming_time_beg}")
asr_text_len = len(tokenizer.encode(onscreen_asr_res))
s2tt_text_len = len(tokenizer.encode(onscreen_s2tt_res))
if asr_text_len > UNFIX_LEN and audio_seconds > MIN_LEN_SEC_AUDIO_FIX and not is_asr_repetition:
pre_previous_asr_text = previous_asr_text
previous_asr_text = tokenizer.decode(tokenizer.encode(onscreen_asr_res)[:-UNFIX_LEN])
if len(previous_asr_text) < len(pre_previous_asr_text):
previous_asr_text = pre_previous_asr_text
elif is_asr_repetition:
pass
else:
previous_asr_text = ""
if s2tt_text_len > UNFIX_LEN and audio_seconds > MIN_LEN_SEC_AUDIO_FIX and not is_s2tt_repetition:
pre_previous_s2tt_text = previous_s2tt_text
previous_s2tt_text = tokenizer.decode(tokenizer.encode(onscreen_s2tt_res)[:-UNFIX_LEN])
if len(previous_s2tt_text) < len(pre_previous_s2tt_text):
previous_s2tt_text = pre_previous_s2tt_text
elif is_s2tt_repetition:
pass
else:
previous_s2tt_text = ""
websocket.streaming_state["previous_asr_text"] = previous_asr_text
websocket.streaming_state["onscreen_asr_res"] = previous_vad_onscreen_asr_text + onscreen_asr_res
websocket.streaming_state["previous_s2tt_text"] = previous_s2tt_text
websocket.streaming_state["onscreen_s2tt_res"] = previous_vad_onscreen_s2tt_text + onscreen_s2tt_res
print("fix asr part:", previous_asr_text)
print("fix s2tt part:", previous_s2tt_text)
async def ws_reset(websocket):
print("ws reset now, total num is ", len(websocket_users))
websocket.streaming_state = {}
websocket.streaming_state["is_final"] = True
websocket.streaming_state["previous_asr_text"] = ""
websocket.streaming_state["previous_s2tt_text"] = ""
websocket.streaming_state["onscreen_asr_res"] = ""
websocket.streaming_state["onscreen_s2tt_res"] = ""
websocket.streaming_state["previous_vad_onscreen_asr_text"] = ""
websocket.streaming_state["previous_vad_onscreen_s2tt_text"] = ""
websocket.status_dict_vad["cache"] = {}
websocket.status_dict_vad["is_final"] = True
await websocket.close()
async def clear_websocket():
for websocket in websocket_users:
await ws_reset(websocket)
websocket_users.clear()
async def ws_serve(websocket, path):
frames = []
frames_asr = []
global websocket_users
# await clear_websocket()
websocket_users.add(websocket)
websocket.streaming_state = {
"previous_asr_text": "",
"previous_s2tt_text": "",
"onscreen_asr_res": "",
"onscreen_s2tt_res": "",
"previous_vad_onscreen_asr_text": "",
"previous_vad_onscreen_s2tt_text": "",
"is_final": False,
}
websocket.status_dict_vad = {"cache": {}, "is_final": False}
websocket.chunk_interval = 10
websocket.vad_pre_idx = 0
speech_start = False
speech_end_i = -1
websocket.wav_name = "microphone"
print("new user connected", flush=True)
try:
async for message in websocket:
if isinstance(message, str):
current_time = datetime.now()
print("DEBUG:" + str(current_time) + " received message:", message)
else:
current_time = datetime.now()
print("DEBUG:" + str(current_time) + " received audio bytes:")
if isinstance(message, str):
messagejson = json.loads(message)
if "is_speaking" in messagejson:
websocket.is_speaking = messagejson["is_speaking"]
websocket.streaming_state["is_final"] = not websocket.is_speaking
if not messagejson["is_speaking"]:
await clear_websocket()
if "chunk_interval" in messagejson:
websocket.chunk_interval = messagejson["chunk_interval"]
if "wav_name" in messagejson:
websocket.wav_name = messagejson.get("wav_name")
if "chunk_size" in messagejson:
chunk_size = messagejson["chunk_size"]
if isinstance(chunk_size, str):
chunk_size = chunk_size.split(",")
chunk_size = [int(x) for x in chunk_size]
if "asr_prompt" in messagejson:
asr_prompt = messagejson["asr_prompt"]
else:
asr_prompt = "Copy:"
if "s2tt_prompt" in messagejson:
s2tt_prompt = messagejson["s2tt_prompt"]
else:
s2tt_prompt = "Translate the following sentence into English:"
websocket.status_dict_vad["chunk_size"] = int(
chunk_size[1] * 60 / websocket.chunk_interval
)
if len(frames_asr) > 0 or not isinstance(message, str):
if not isinstance(message, str):
frames.append(message)
duration_ms = len(message) // 32
websocket.vad_pre_idx += duration_ms
# asr online
websocket.streaming_state["is_final"] = speech_end_i != -1
if (
(len(frames_asr) % websocket.chunk_interval == 0
or websocket.streaming_state["is_final"])
and len(frames_asr) != 0
):
audio_in = b"".join(frames_asr)
try:
await streaming_transcribe(websocket, audio_in, asr_prompt=asr_prompt, s2tt_prompt=s2tt_prompt)
except Exception as e:
print(f"error in streaming, {e}")
print(f"error in streaming, {websocket.streaming_state}")
if speech_start:
frames_asr.append(message)
# vad online
try:
speech_start_i, speech_end_i = await async_vad(websocket, message)
except:
print("error in vad")
if speech_start_i != -1:
speech_start = True
beg_bias = (websocket.vad_pre_idx - speech_start_i) // duration_ms
frames_pre = frames[-beg_bias:]
frames_asr = []
frames_asr.extend(frames_pre)
# vad end
if speech_end_i != -1 or not websocket.is_speaking:
audio_in = b"".join(frames_asr)
try:
await streaming_transcribe(websocket, audio_in, asr_prompt=asr_prompt, s2tt_prompt=s2tt_prompt)
except Exception as e:
print(f"error in streaming, {e}")
print(f"error in streaming, {websocket.streaming_state}")
frames_asr = []
speech_start = False
websocket.streaming_state["previous_asr_text"] = ""
websocket.streaming_state["previous_s2tt_text"] = ""
now_onscreen_asr_res = websocket.streaming_state.get("onscreen_asr_res", "")
now_onscreen_s2tt_res = websocket.streaming_state.get("onscreen_s2tt_res", "")
if len(tokenizer.encode(now_onscreen_asr_res.split("\n\n")[-1])) < MIN_LEN_PER_PARAGRAPH or len(tokenizer.encode(now_onscreen_s2tt_res.split("\n\n")[-1])) < MIN_LEN_PER_PARAGRAPH:
if now_onscreen_asr_res.endswith(".") or now_onscreen_asr_res.endswith("?") or now_onscreen_asr_res.endswith("!"):
now_onscreen_asr_res += " "
if now_onscreen_s2tt_res.endswith(".") or now_onscreen_s2tt_res.endswith("?") or now_onscreen_s2tt_res.endswith("!"):
now_onscreen_s2tt_res += " "
websocket.streaming_state["previous_vad_onscreen_asr_text"] = now_onscreen_asr_res
websocket.streaming_state["previous_vad_onscreen_s2tt_text"] = now_onscreen_s2tt_res
else:
websocket.streaming_state["previous_vad_onscreen_asr_text"] = now_onscreen_asr_res + "\n\n"
websocket.streaming_state["previous_vad_onscreen_s2tt_text"] = now_onscreen_s2tt_res + "\n\n"
if not websocket.is_speaking:
websocket.vad_pre_idx = 0
frames = []
websocket.status_dict_vad["cache"] = {}
websocket.streaming_state["previous_asr_text"] = ""
websocket.streaming_state["previous_s2tt_text"] = ""
else:
frames = frames[-20:]
else:
print(f"message: {message}")
except websockets.ConnectionClosed:
print("ConnectionClosed...", websocket_users, flush=True)
await ws_reset(websocket)
websocket_users.remove(websocket)
except websockets.InvalidState:
print("InvalidState...")
except Exception as e:
print("Exception:", e)
async def async_vad(websocket, audio_in):
current_time = datetime.now()
print("DEBUG:" + str(current_time) + " call vad function:")
segments_result = model_vad.generate(input=audio_in, **websocket.status_dict_vad)[0]["value"]
# print(segments_result)
speech_start = -1
speech_end = -1
if len(segments_result) == 0 or len(segments_result) > 1:
return speech_start, speech_end
if segments_result[0][0] != -1:
speech_start = segments_result[0][0]
if segments_result[0][1] != -1:
speech_end = segments_result[0][1]
return speech_start, speech_end
if False:
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
# Generate with Lets Encrypt, copied to this location, chown to current user and 400 permissions
ssl_cert = args.certfile
ssl_key = args.keyfile
ssl_context.load_cert_chain(ssl_cert, keyfile=ssl_key)
start_server = websockets.serve(
ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None, ssl=ssl_context
)
else:
start_server = websockets.serve(
ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None
)
asyncio.get_event_loop().run_until_complete(start_server)
asyncio.get_event_loop().run_forever()