streaming asr

This commit is contained in:
yangyexin.yyx 2024-08-28 09:48:06 +08:00
parent 032d429a94
commit 366603d4ed
2 changed files with 778 additions and 0 deletions

View File

@ -0,0 +1,394 @@
# -*- encoding: utf-8 -*-
import os
import time
import websockets, ssl
import asyncio
# import threading
import argparse
import json
import traceback
from multiprocessing import Process
# from funasr.fileio.datadir_writer import DatadirWriter
import logging
logging.basicConfig(level=logging.ERROR)

# ---------------------------------------------------------------------------
# Command-line interface of the streaming ASR websocket client.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument(
    "--host", type=str, default="localhost", required=False, help="host ip, localhost, 0.0.0.0"
)
parser.add_argument(
    "--port", type=int, default=10095, required=False, help="websocket server port"
)
parser.add_argument("--chunk_size", type=str, default="5, 10, 5", help="chunk")
parser.add_argument("--encoder_chunk_look_back", type=int, default=4, help="chunk")
parser.add_argument("--decoder_chunk_look_back", type=int, default=0, help="chunk")
parser.add_argument("--chunk_interval", type=int, default=10, help="chunk")
parser.add_argument(
    "--hotword",
    type=str,
    default="",
    help="hotword file path, one hotword perline (e.g.:阿里巴巴 20)",
)
parser.add_argument("--audio_in", type=str, default=None, help="audio_in")
parser.add_argument("--audio_fs", type=int, default=16000, help="audio_fs")
# NOTE(review): store_true combined with default=True means this option is
# always True, whether or not the flag is given. Left untouched on purpose to
# keep the command line backward compatible.
parser.add_argument(
    "--send_without_sleep",
    action="store_true",
    default=True,
    help="if audio_in is set, send_without_sleep",
)
parser.add_argument("--thread_num", type=int, default=1, help="thread_num")
parser.add_argument(
    "--words_max_print",
    type=int,
    default=10000,
    help="maximum number of trailing characters to print",
)
parser.add_argument("--output_dir", type=str, default=None, help="output_dir")
parser.add_argument("--ssl", type=int, default=1, help="1 for ssl connect, 0 for no ssl")
parser.add_argument("--use_itn", type=int, default=1, help="1 for using itn, 0 for not itn")
parser.add_argument("--mode", type=str, default="2pass", help="offline, online, 2pass")

args = parser.parse_args()
# "5, 10, 5" -> [5, 10, 5]; int() tolerates the surrounding blanks.
args.chunk_size = [int(x) for x in args.chunk_size.split(",")]
print(args)

# voices = asyncio.Queue()
from queue import Queue

voices = Queue()
# Set by the receiver coroutine once the server reports a final result.
offline_msg_done = False

if args.output_dir is not None:
    # exist_ok avoids the check-then-create race when several worker
    # processes execute this module-level code at the same time.
    os.makedirs(args.output_dir, exist_ok=True)
async def record_microphone():
    """Capture microphone audio and stream it to the server.

    Sends one JSON configuration message over the module-level ``websocket``
    and then forwards raw 16-bit mono PCM chunks read from PyAudio until the
    coroutine is cancelled or an error is raised.
    """
    import pyaudio

    FORMAT = pyaudio.paInt16
    CHANNELS = 1
    RATE = 16000
    # Chunk duration in milliseconds, converted to a frame count below.
    chunk_size = 60 * args.chunk_size[1] / args.chunk_interval
    CHUNK = int(RATE / 1000 * chunk_size)

    # hotwords: either a file with "<word ...> <weight>" lines or a literal string
    hotword_msg = ""
    if args.hotword.strip() != "":
        if os.path.exists(args.hotword):
            fst_dict = {}
            # The context manager closes the handle the original code leaked.
            with open(args.hotword) as f_scp:
                for line in f_scp:
                    words = line.strip().split(" ")
                    if len(words) < 2:
                        print("Please checkout format of hotwords")
                        continue
                    try:
                        fst_dict[" ".join(words[:-1])] = int(words[-1])
                    except ValueError:
                        print("Please checkout format of hotwords")
            hotword_msg = json.dumps(fst_dict)
        else:
            hotword_msg = args.hotword

    use_itn = args.use_itn != 0

    message = json.dumps(
        {
            "mode": args.mode,
            "chunk_size": args.chunk_size,
            "chunk_interval": args.chunk_interval,
            "encoder_chunk_look_back": args.encoder_chunk_look_back,
            "decoder_chunk_look_back": args.decoder_chunk_look_back,
            "wav_name": "microphone",
            "is_speaking": True,
            "hotwords": hotword_msg,
            "itn": use_itn,
        }
    )

    p = pyaudio.PyAudio()
    stream = p.open(
        format=FORMAT, channels=CHANNELS, rate=RATE, input=True, frames_per_buffer=CHUNK
    )
    try:
        await websocket.send(message)
        while True:
            # NOTE(review): stream.read() is a blocking call inside a
            # coroutine; the short sleep below is what lets the receiver
            # task run between chunks.
            data = stream.read(CHUNK)
            await websocket.send(data)
            await asyncio.sleep(0.0005)
    finally:
        # Release the audio device however the loop ends.
        stream.stop_stream()
        stream.close()
        p.terminate()
async def record_from_scp(chunk_begin, chunk_size):
    """Stream audio files to the server over the module-level ``websocket``.

    Args:
        chunk_begin: index of the first entry of the wav list to send.
        chunk_size: number of entries to send; values <= 0 send the whole list.

    ``args.audio_in`` is either a single audio path or a ``.scp`` list whose
    lines are ``<name> <path>`` (or just ``<path>``). For every file a JSON
    header is sent, followed by the audio in fixed-size chunks and a final
    ``{"is_speaking": false}`` marker. The connection is closed at the end.
    """
    global offline_msg_done

    if args.audio_in.endswith(".scp"):
        with open(args.audio_in) as f_scp:
            wavs = f_scp.readlines()
    else:
        wavs = [args.audio_in]

    # hotwords: either a file with "<word ...> <weight>" lines or a literal string
    hotword_msg = ""
    if args.hotword.strip() != "":
        if os.path.exists(args.hotword):
            fst_dict = {}
            with open(args.hotword) as f_hot:
                for line in f_hot:
                    words = line.strip().split(" ")
                    if len(words) < 2:
                        print("Please checkout format of hotwords")
                        continue
                    try:
                        fst_dict[" ".join(words[:-1])] = int(words[-1])
                    except ValueError:
                        print("Please checkout format of hotwords")
            hotword_msg = json.dumps(fst_dict)
        else:
            hotword_msg = args.hotword
        print(hotword_msg)

    use_itn = args.use_itn != 0

    if chunk_size > 0:
        wavs = wavs[chunk_begin : chunk_begin + chunk_size]

    for wav in wavs:
        wav_splits = wav.strip().split()
        if not wav_splits:
            # Blank line in the scp file: nothing to send.
            continue
        wav_name = wav_splits[0] if len(wav_splits) > 1 else "demo"
        wav_path = wav_splits[1] if len(wav_splits) > 1 else wav_splits[0]

        # Reset per file so that a previous ".wav" sample rate or an "others"
        # format does not leak into the header of the next file.
        sample_rate = args.audio_fs
        wav_format = "pcm"

        if wav_path.endswith(".pcm"):
            with open(wav_path, "rb") as f:
                audio_bytes = f.read()
        elif wav_path.endswith(".wav"):
            import wave

            with wave.open(wav_path, "rb") as wav_file:
                sample_rate = wav_file.getframerate()
                audio_bytes = bytes(wav_file.readframes(wav_file.getnframes()))
        else:
            # Any other container is sent as-is and flagged for the server.
            wav_format = "others"
            with open(wav_path, "rb") as f:
                audio_bytes = f.read()

        # Bytes per chunk: chunk duration (ms) * sample rate * 2 bytes/sample.
        stride = int(60 * args.chunk_size[1] / args.chunk_interval / 1000 * sample_rate * 2)
        chunk_num = (len(audio_bytes) - 1) // stride + 1

        # send first time
        message = json.dumps(
            {
                "mode": args.mode,
                "chunk_size": args.chunk_size,
                "chunk_interval": args.chunk_interval,
                "encoder_chunk_look_back": args.encoder_chunk_look_back,
                "decoder_chunk_look_back": args.decoder_chunk_look_back,
                "audio_fs": sample_rate,
                "wav_name": wav_name,
                "wav_format": wav_format,
                "is_speaking": True,
                "hotwords": hotword_msg,
                "itn": use_itn,
            }
        )
        await websocket.send(message)

        # Pace the upload at the duration of one chunk.
        sleep_duration = 60 * args.chunk_size[1] / args.chunk_interval / 1000
        for i in range(chunk_num):
            beg = i * stride
            await websocket.send(audio_bytes[beg : beg + stride])
            if i == chunk_num - 1:
                # Tell the server that this utterance is complete.
                await websocket.send(json.dumps({"is_speaking": False}))
            await asyncio.sleep(sleep_duration)

    if not args.mode == "offline":
        await asyncio.sleep(2)

    # offline model need to wait for message recved
    if args.mode == "offline":
        while not offline_msg_done:
            await asyncio.sleep(1)

    await websocket.close()
async def message(id):
    """Receive recognition results from the server and display them.

    Args:
        id: label used in the printed prefix and in the name of the result
            file ``text.<id>`` written under ``args.output_dir`` (if set).

    Runs until receiving or decoding a message raises; the exception is
    printed and the coroutine returns. The result file is always closed.
    """
    global websocket, voices, offline_msg_done
    text_print = ""
    text_print_2pass_online = ""
    text_print_2pass_offline = ""

    ibest_writer = None
    if args.output_dir is not None:
        ibest_writer = open(
            os.path.join(args.output_dir, "text.{}".format(id)), "a", encoding="utf-8"
        )

    try:
        while True:
            meg = await websocket.recv()
            meg = json.loads(meg)
            wav_name = meg.get("wav_name", "demo")
            text = meg["text"]
            timestamp = meg.get("timestamp", "")
            offline_msg_done = meg.get("is_final", False)

            if ibest_writer is not None:
                if timestamp != "":
                    text_write_line = "{}\t{}\t{}\n".format(wav_name, text, timestamp)
                else:
                    text_write_line = "{}\t{}\n".format(wav_name, text)
                ibest_writer.write(text_write_line)

            if "mode" not in meg:
                continue

            if meg["mode"] == "online":
                # Each online message already carries the full text so far.
                text_print = text
                os.system("clear")
                print("\rpid" + str(id) + ": " + text_print)
            elif meg["mode"] == "offline":
                if timestamp != "":
                    text_print += "{} timestamp: {}".format(text, timestamp)
                else:
                    text_print += "{}".format(text)
                print("\rpid" + str(id) + ": " + wav_name + ": " + text_print)
                offline_msg_done = True
            else:
                if meg["mode"] == "2pass-online":
                    text_print_2pass_online += "{}".format(text)
                    text_print = text_print_2pass_offline + text_print_2pass_online
                else:
                    # Offline pass of 2pass: replaces the pending online text.
                    text_print_2pass_online = ""
                    text_print = text_print_2pass_offline + "{}".format(text)
                    text_print_2pass_offline += "{}".format(text)
                text_print = text_print[-args.words_max_print :]
                os.system("clear")
                print("\rpid" + str(id) + ": " + text_print)
    except Exception as e:
        print("Exception:", e)
    finally:
        # Flush and release the result file; the original never closed it.
        if ibest_writer is not None:
            ibest_writer.close()
async def ws_client(id, chunk_begin, chunk_size):
    """Open one connection per input item and run sender and receiver on it.

    Args:
        id: worker label, combined with the item index for the receiver.
        chunk_begin: index of the first wav entry handled by this worker.
        chunk_size: number of wav entries handled by this worker.

    With no ``args.audio_in`` (microphone mode) exactly one connection is
    made. Always finishes by raising ``SystemExit(0)``.
    """
    global websocket, voices, offline_msg_done

    if args.audio_in is None:
        chunk_begin = 0
        chunk_size = 1

    for i in range(chunk_begin, chunk_begin + chunk_size):
        offline_msg_done = False
        voices = Queue()
        if args.ssl == 1:
            # Explicit client protocol: a bare ssl.SSLContext() is deprecated.
            # NOTE: certificate and hostname verification are disabled on
            # purpose (as before), so the server identity is NOT checked.
            ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
            ssl_context.check_hostname = False
            ssl_context.verify_mode = ssl.CERT_NONE
            uri = "wss://{}:{}".format(args.host, args.port)
        else:
            uri = "ws://{}:{}".format(args.host, args.port)
            ssl_context = None
        print("connect to", uri)
        async with websockets.connect(
            uri, subprotocols=["binary"], ping_interval=None, ssl=ssl_context
        ) as websocket:
            if args.audio_in is not None:
                task = asyncio.create_task(record_from_scp(i, 1))
            else:
                task = asyncio.create_task(record_microphone())
            task3 = asyncio.create_task(message(str(id) + "_" + str(i)))  # processid+fileid
            await asyncio.gather(task, task3)

    # Same effect as the site-provided exit(0), without depending on `site`.
    raise SystemExit(0)
def one_thread(id, chunk_begin, chunk_size):
    """Process entry point: run ``ws_client`` on a fresh event loop.

    ``asyncio.run`` creates and closes its own loop, so this no longer relies
    on ``asyncio.get_event_loop()`` implicitly creating one (deprecated, and
    an error on recent Python versions). ``ws_client`` always ends by raising
    ``SystemExit``, which propagates out of ``asyncio.run``; the former
    ``run_forever()`` call after it could never be reached.
    """
    asyncio.run(ws_client(id, chunk_begin, chunk_size))
if __name__ == "__main__":
    if args.audio_in is None:
        # for microphone: a single worker process
        p = Process(target=one_thread, args=(0, 0, 0))
        p.start()
        p.join()
        print("end")
    else:
        # calculate the number of wavs for each process
        if args.audio_in.endswith(".scp"):
            with open(args.audio_in) as f_scp:
                wavs = f_scp.readlines()
        else:
            wavs = [args.audio_in]

        total_len = len(wavs)
        # Every worker gets chunk_size items; the first remain_wavs workers
        # get one extra so that all items are covered.
        chunk_size, remain_wavs = divmod(total_len, args.thread_num)

        process_list = []
        chunk_begin = 0
        for i in range(args.thread_num):
            now_chunk_size = chunk_size + 1 if i < remain_wavs else chunk_size
            if now_chunk_size == 0:
                # Fewer wavs than workers: do not start workers with no input.
                break
            # process i handle wavs at chunk_begin and size of now_chunk_size
            p = Process(target=one_thread, args=(i, chunk_begin, now_chunk_size))
            chunk_begin = chunk_begin + now_chunk_size
            p.start()
            process_list.append(p)

        # Wait for every worker (the original joined only the last one).
        for p in process_list:
            p.join()
        print("end")
"""
python funasr_wss_client.py --host "127.0.0.1" --port 10095 --audio_in audio_file
"""

View File

@ -0,0 +1,384 @@
import asyncio
import json
import websockets
import time
import logging
import tracemalloc
import numpy as np
import argparse
import ssl
import os
import torch
import torchaudio
from transformers import TextIteratorStreamer
from threading import Thread
import traceback
# ---------------------------------------------------------------------------
# Command-line interface of the streaming ASR websocket server.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument(
    "--host", type=str, default="127.0.0.1", required=False, help="host ip, localhost, 0.0.0.0"
)
parser.add_argument(
    "--port", type=int, default=10095, required=False, help="websocket server port"
)

# Model identifiers and revisions.
parser.add_argument(
    "--asr_model",
    type=str,
    default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
    help="model from modelscope",
)
parser.add_argument("--asr_model_revision", type=str, default="master", help="")
parser.add_argument(
    "--asr_model_online",
    type=str,
    default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online",
    help="model from modelscope",
)
parser.add_argument("--asr_model_online_revision", type=str, default="master", help="")
parser.add_argument(
    "--vad_model",
    type=str,
    default="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
    help="model from modelscope",
)
parser.add_argument("--vad_model_revision", type=str, default="master", help="")

# Compute resources.
parser.add_argument("--ngpu", type=int, default=1, help="0 for cpu, 1 for gpu")
parser.add_argument("--device", type=str, default="cuda", help="cuda, cpu")
parser.add_argument("--ncpu", type=int, default=4, help="cpu cores")

# TLS material; an empty --certfile makes the server listen without TLS.
parser.add_argument(
    "--certfile",
    type=str,
    default="../../ssl_key/server.crt",
    required=False,
    help="certfile for ssl",
)
parser.add_argument(
    "--keyfile",
    type=str,
    default="../../ssl_key/server.key",
    required=False,
    help="keyfile for ssl",
)

args = parser.parse_args()
# Sockets of the currently connected clients.
websocket_users = set()

print("model loading")
from funasr import AutoModel

# Voice activity detection model, configured from the command line.
model_vad = AutoModel(
    model=args.vad_model,
    model_revision=args.vad_model_revision,
    ngpu=args.ngpu,
    ncpu=args.ncpu,
    device=args.device,
    disable_pbar=True,
    disable_log=True,
    # chunk_size=60,
)

from modelscope.hub.api import HubApi

api = HubApi()
# Log in to the hub only when a token is supplied through the environment.
if "key" in os.environ:
    key = os.environ["key"]
    api.login(key)

from modelscope.hub.snapshot_download import snapshot_download

# The checkpoints are read from fixed local paths; snapshot_download is
# imported above but not called here.
llm_dir = "/cpfs_speech/zhifu.gzf/init_model/qwen/Qwen2-7B-Instruct"
audio_encoder_dir = "/nfs/yangyexin.yyx/init_model/iic/SenseVoiceModelscope_0712"
device = "cuda:0"

all_file_paths = [
    "/nfs/yangyexin.yyx/init_model/audiolm_v14_20240824_train_encoder_all_20240822_lr1e-4_warmup2350/"
]

llm_kwargs = dict(num_beams=1, do_sample=False)
# Thresholds (in tokenizer tokens) used by the streaming transcription code.
unfix_len = 5
max_streaming_res_onetime = 100

ckpt_dir = all_file_paths[0]

model_llm = AutoModel(
    model=ckpt_dir,
    device=device,
    fp16=False,
    bf16=False,
    llm_dtype="bf16",
    max_length=1024,
    llm_kwargs=llm_kwargs,
    llm_conf=dict(init_param_path=llm_dir),
    tokenizer_conf=dict(init_param_path=llm_dir),
    audio_encoder=audio_encoder_dir,
)

# Pieces of the loaded pipeline that the request handlers use directly.
model = model_llm.model
frontend = model_llm.kwargs["frontend"]
tokenizer = model_llm.kwargs["tokenizer"]

model_dict = dict(model=model, frontend=frontend, tokenizer=tokenizer)
def load_bytes(input):
    """Convert raw 16-bit PCM bytes into a float32 waveform.

    Args:
        input: bytes-like object holding native-endian int16 samples; its
            length must be a multiple of 2.

    Returns:
        A 1-D ``numpy.float32`` array with one value per sample, scaled by
        1/32768 so that it lies in ``[-1.0, 1.0)``.

    Raises:
        ValueError: if the buffer size is not a multiple of 2 (raised by
            ``numpy.frombuffer``).
    """
    # The dtype is fixed to int16, so the former integer/float dtype checks
    # could never fail and the offset was always 0; only the scaling remains.
    samples = np.frombuffer(input, dtype=np.int16)
    return samples.astype(np.float32) / 32768
# Decode the audio buffered so far with the LLM-based ASR model and stream the
# growing transcript to the client as JSON messages.
#
# websocket: connection object; this function reads .streaming_state,
#            .wav_name and .is_speaking, calls .send(), and overwrites
#            .streaming_state before returning.
# audio_in:  raw audio bytes (interpreted as int16 samples by load_bytes).
# his_state: dict providing "model" and "tokenizer"; defaults to model_dict.
# prompt:    instruction text placed before the speech placeholder;
#            defaults to "Copy:".
# Returns None; results are delivered through websocket.send().
async def streaming_transcribe(websocket, audio_in, his_state=None, prompt=None):
if his_state is None:
his_state = model_dict
model = his_state["model"]
tokenizer = his_state["tokenizer"]
# Text kept from the previous pass on this connection (empty on the first pass).
if websocket.streaming_state is None:
previous_asr_text = ""
else:
previous_asr_text = websocket.streaming_state["previous_asr_text"]
if prompt is None:
prompt = "Copy:"
# Sample count / 16000 -> duration in seconds (assumes 16 kHz input — TODO confirm).
audio_seconds = load_bytes(audio_in).shape[0] / 16000
print(f"Streaming audio length: {audio_seconds} seconds")
# Build the chat-style request: system turn, user turn with the audio, assistant turn.
asr_content = []
system_prompt = "You are a helpful assistant."
asr_content.append({"role": "system", "content": system_prompt})
# The text kept from the previous pass is placed after the assistant tag,
# presumably so that generation continues from it — TODO confirm.
asr_user_prompt = f"{prompt}<|startofspeech|>!!<|endofspeech|><|im_end|>\n<|im_start|>assistant\n{previous_asr_text}"
asr_content.append({"role": "user", "content": asr_user_prompt, "audio": audio_in})
asr_content.append({"role": "assistant", "content": "target_out"})
streaming_asr_time_beg = time.time()
# Only the returned input embeddings are used below; the other return values are ignored.
asr_inputs_embeds, contents, batch, source_ids, meta_data = model.inference_prepare(
[asr_content], None, "test_demo", tokenizer, frontend, device=device, infer_with_assistant_input=True
)
asr_model_inputs = {}
asr_model_inputs["inputs_embeds"] = asr_inputs_embeds
print("previous_asr_text:", previous_asr_text)
# model.llm.generate runs in a worker thread and feeds the streamer, which is
# consumed in the loop below. The thread is started but never joined here.
streamer = TextIteratorStreamer(tokenizer)
generation_kwargs = dict(asr_model_inputs, streamer=streamer, max_new_tokens=1024)
thread = Thread(target=model.llm.generate, kwargs=generation_kwargs)
thread.start()
# The text shown to the client starts from the carried-over prefix.
onscreen_asr_res = previous_asr_text
beg_llm = time.time()
# NOTE(review): this is a plain `for` over the streamer inside a coroutine;
# presumably it blocks the event loop while waiting for tokens — confirm.
for new_text in streamer:
end_llm = time.time()
print(
f"generated new text {new_text}, time_llm_decode: {end_llm - beg_llm:.2f}"
)
if len(new_text) > 0:
# The end-of-turn marker is removed before the text is shown.
onscreen_asr_res += new_text.replace("<|im_end|>", "")
mode = "online"
# NOTE(review): "is_final" carries websocket.is_speaking unchanged, which
# looks inverted relative to the field name — confirm with the client.
message = json.dumps(
{
"mode": mode,
"text": onscreen_asr_res,
"wav_name": websocket.wav_name,
"is_final": websocket.is_speaking,
}
)
await websocket.send(message)
streaming_asr_time_end = time.time()
print(f"Streaming ASR inference time: {streaming_asr_time_end - streaming_asr_time_beg}")
# Decide how much of this pass's text is carried into the next pass
# (lengths are counted in tokenizer tokens).
asr_text_len = len(tokenizer.encode(onscreen_asr_res))
if asr_text_len > unfix_len and audio_seconds > 1.1:
if asr_text_len <= max_streaming_res_onetime:
# Drop the last unfix_len tokens; the remainder becomes the next prefix.
previous_asr_text = tokenizer.decode(tokenizer.encode(onscreen_asr_res)[:-unfix_len])
else:
# Output longer than max_streaming_res_onetime tokens: the on-screen text
# falls back to the previous prefix, which itself is left unchanged.
onscreen_asr_res = previous_asr_text
else:
# Short output, or audio of at most 1.1 s: carry nothing over.
previous_asr_text = ""
websocket.streaming_state = {}
websocket.streaming_state["previous_asr_text"] = previous_asr_text
print("fix asr part:", previous_asr_text)
print("model loaded! only support one client at the same time now!!!!")


async def ws_reset(websocket):
    """Clear the per-connection recognition state and close the socket.

    Both the online-ASR and the VAD state dictionaries get an empty cache and
    ``is_final`` set to True; the streaming text state is dropped.
    """
    print("ws reset now, total num is ", len(websocket_users))

    websocket.status_dict_asr_online.update(cache={}, is_final=True)
    websocket.streaming_state = None
    websocket.status_dict_vad.update(cache={}, is_final=True)

    await websocket.close()
async def clear_websocket():
    """Reset and close every registered client connection, then forget them."""
    # Iterate over a snapshot: ws_reset awaits, and during that await other
    # coroutines may add or remove entries of websocket_users, which would
    # raise "Set changed size during iteration" on the live set.
    for websocket in list(websocket_users):
        await ws_reset(websocket)
    websocket_users.clear()
async def ws_serve(websocket, path=None):
    """Handle one client connection: configuration, VAD and streaming ASR.

    Args:
        websocket: the connection object; per-connection state is stored on
            it as attributes.
        path: request path passed by websockets versions that call handlers
            with two arguments; unused, optional so that one-argument
            handler calls also work.

    Text messages are JSON configuration updates; binary messages are audio
    chunks. Audio is run through the VAD, and while speech is active the
    accumulated frames are periodically decoded by ``streaming_transcribe``.
    """
    frames = []
    frames_asr = []
    global websocket_users
    # await clear_websocket()
    websocket_users.add(websocket)
    websocket.status_dict_asr = {}
    websocket.status_dict_asr_online = {"cache": {}, "is_final": False}
    websocket.status_dict_vad = {"cache": {}, "is_final": False}
    websocket.chunk_interval = 10
    websocket.vad_pre_idx = 0
    speech_start = False
    speech_end_i = -1
    websocket.wav_name = "microphone"
    websocket.mode = "online"
    websocket.streaming_state = None
    # Defined up front: it is read for every audio chunk, even when the
    # client has not sent an "is_speaking" field yet.
    websocket.is_speaking = True
    print("new user connected", flush=True)

    try:
        async for message in websocket:
            if isinstance(message, str):
                messagejson = json.loads(message)

                if "is_speaking" in messagejson:
                    websocket.is_speaking = messagejson["is_speaking"]
                    websocket.status_dict_asr_online["is_final"] = not websocket.is_speaking
                if "chunk_interval" in messagejson:
                    websocket.chunk_interval = messagejson["chunk_interval"]
                if "wav_name" in messagejson:
                    websocket.wav_name = messagejson.get("wav_name")
                if "chunk_size" in messagejson:
                    chunk_size = messagejson["chunk_size"]
                    if isinstance(chunk_size, str):
                        chunk_size = chunk_size.split(",")
                    websocket.status_dict_asr_online["chunk_size"] = [int(x) for x in chunk_size]
                if "encoder_chunk_look_back" in messagejson:
                    websocket.status_dict_asr_online["encoder_chunk_look_back"] = messagejson[
                        "encoder_chunk_look_back"
                    ]
                if "decoder_chunk_look_back" in messagejson:
                    websocket.status_dict_asr_online["decoder_chunk_look_back"] = messagejson[
                        "decoder_chunk_look_back"
                    ]
                # The client sends the key "hotwords"; the original tested for
                # "hotword" and then read "hotwords", so the value was either
                # ignored or raised KeyError.
                if "hotwords" in messagejson:
                    websocket.status_dict_asr["hotword"] = messagejson["hotwords"]
                if "mode" in messagejson:
                    websocket.mode = messagejson["mode"]

            # VAD chunk length derived from the ASR chunk configuration.
            websocket.status_dict_vad["chunk_size"] = int(
                websocket.status_dict_asr_online["chunk_size"][1] * 60 / websocket.chunk_interval
            )

            if len(frames_asr) > 0 or not isinstance(message, str):
                if not isinstance(message, str):
                    frames.append(message)
                    # 32 bytes per millisecond (16 kHz, 16-bit mono assumed).
                    duration_ms = len(message) // 32
                    websocket.vad_pre_idx += duration_ms

                    # asr online
                    websocket.status_dict_asr_online["is_final"] = speech_end_i != -1
                    if (
                        len(frames_asr) % websocket.chunk_interval == 0
                        or websocket.status_dict_asr_online["is_final"]
                    ) and len(frames_asr) != 0:
                        if websocket.mode == "2pass" or websocket.mode == "online":
                            audio_in = b"".join(frames_asr)
                            try:
                                await streaming_transcribe(websocket, audio_in)
                            except Exception:
                                # Exception (not a bare except) so that task
                                # cancellation is not swallowed.
                                print(
                                    f"error in asr streaming, {websocket.status_dict_asr_online}"
                                )
                    if speech_start:
                        frames_asr.append(message)

                    # vad online
                    try:
                        speech_start_i, speech_end_i = await async_vad(websocket, message)
                    except Exception:
                        print("error in vad")
                        # Treat a failed VAD call as "no boundary found" rather
                        # than reusing stale (or undefined) indices.
                        speech_start_i, speech_end_i = -1, -1
                    if speech_start_i != -1:
                        speech_start = True
                        # Re-seed the ASR buffer with the frames since speech onset.
                        beg_bias = (websocket.vad_pre_idx - speech_start_i) // duration_ms
                        frames_pre = frames[-beg_bias:]
                        frames_asr = []
                        frames_asr.extend(frames_pre)

                # asr punc offline
                if speech_end_i != -1 or not websocket.is_speaking:
                    frames_asr = []
                    speech_start = False
                    websocket.status_dict_asr_online["cache"] = {}
                    websocket.streaming_state = None
                    if not websocket.is_speaking:
                        websocket.vad_pre_idx = 0
                        frames = []
                        websocket.status_dict_vad["cache"] = {}
                        websocket.streaming_state = None
                    else:
                        # Keep a short history for the next speech onset.
                        frames = frames[-20:]
            else:
                print(f"message: {message}")

    except websockets.ConnectionClosed:
        print("ConnectionClosed...", websocket_users, flush=True)
        await ws_reset(websocket)
    except websockets.InvalidState:
        print("InvalidState...")
    except Exception as e:
        print("Exception:", e)
    finally:
        # Always unregister, not only on ConnectionClosed, so finished
        # connections do not accumulate in websocket_users.
        websocket_users.discard(websocket)
async def async_vad(websocket, audio_in):
    """Run the VAD model on one audio chunk and report speech boundaries.

    Args:
        websocket: connection whose ``status_dict_vad`` is passed to the
            model as keyword arguments.
        audio_in: the audio chunk handed to the model as ``input``.

    Returns:
        ``(speech_start, speech_end)``; each is -1 when the model reports no
        such boundary, or when it returns zero or more than one segment.
    """
    vad_output = model_vad.generate(input=audio_in, **websocket.status_dict_vad)
    segments = vad_output[0]["value"]

    # Only a result with exactly one segment is interpreted.
    if len(segments) != 1:
        return -1, -1

    begin, end = segments[0][0], segments[0][1]
    speech_start = begin if begin != -1 else -1
    speech_end = end if end != -1 else -1
    return speech_start, speech_end
# Create the event loop explicitly and register it as the current one before
# building the server: relying on asyncio.get_event_loop() to create a loop
# implicitly is deprecated and fails on recent Python versions.
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

if len(args.certfile) > 0:
    ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)

    # Generate with Lets Encrypt, copied to this location, chown to current user and 400 permissions
    ssl_cert = args.certfile
    ssl_key = args.keyfile

    ssl_context.load_cert_chain(ssl_cert, keyfile=ssl_key)
    start_server = websockets.serve(
        ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None, ssl=ssl_context
    )
else:
    # Empty --certfile: serve plain (non-TLS) websocket connections.
    start_server = websockets.serve(
        ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None
    )

loop.run_until_complete(start_server)
loop.run_forever()