This commit is contained in:
游雁 2024-08-22 11:32:22 +08:00
parent 2d29a079ee
commit 70bdbabcb2
4 changed files with 590 additions and 67 deletions

View File

@ -14,8 +14,8 @@ from threading import Thread
import torch
import time
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_flash_sdp(False)
# torch.backends.cuda.enable_mem_efficient_sdp(False)
# torch.backends.cuda.enable_flash_sdp(False)
from funasr import AutoModel

View File

@ -7,6 +7,7 @@ from funasr.register import tables
from funasr.utils.load_utils import extract_fbank, load_audio_text_image_video
import math
@tables.register("dataset_classes", "OpenAIDataset")
class OpenAIDataset(torch.utils.data.Dataset):
"""
@ -776,6 +777,7 @@ class OpenAIDatasetMultiTurnCodec(torch.utils.data.Dataset):
return outputs
@tables.register("dataset_classes", "OpenAIDatasetMultiTurnCodecMel")
class OpenAIDatasetMultiTurnCodecMel(torch.utils.data.Dataset):
"""
@ -881,7 +883,8 @@ class OpenAIDatasetMultiTurnCodecMel(torch.utils.data.Dataset):
audio,
audio_len,
input_mask,
) = ([], [], [], [], [], [], [], [], [], [], [], [])
input_mask_beg,
) = ([], [], [], [], [], [], [], [], [], [], [], [], [])
multiturn_num = len(system)
for i, (system_prompt, user_prompt, target_out) in enumerate(
@ -998,6 +1001,19 @@ class OpenAIDatasetMultiTurnCodecMel(torch.utils.data.Dataset):
fake_token_len += [fake_token_len_i]
source_mask = [-100] * len(source_ids)
if i == 0:
sys_prompt = f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
sys_prompt_len = self.tokenizer.encode(sys_prompt)
input_mask_i = (
[1] * len(sys_prompt_len) + [0] * len(source_ids) + [0] * len(target_ids)
)
else:
input_mask_i = (
[1] * len(input_ids) + [0] * len(source_ids) + [0] * len(target_ids)
)
input_mask_i = torch.tensor(input_mask_i, dtype=torch.int64)
input_mask_beg.append(input_mask_i)
input_mask_i = [1] * len(input_ids) + [1] * len(source_ids) + [0] * len(target_ids)
input_mask_i = torch.tensor(input_mask_i, dtype=torch.int64)
input_mask.append(input_mask_i)
@ -1043,6 +1059,8 @@ class OpenAIDatasetMultiTurnCodecMel(torch.utils.data.Dataset):
output["audio_len"] = audio_len # torch.tensor(audio_len, dtype=torch.int32)
if len(input_mask) > 0:
output["input_mask"] = input_mask
output["input_mask_beg"] = input_mask_beg
if key is not None:
output["key"] = key
break
@ -1420,6 +1438,7 @@ class OpenAIDatasetMultiTurnCodecMel2(torch.utils.data.Dataset):
return outputs
@tables.register("dataset_classes", "OpenAIDatasetMultiTurnForFullDuplexVAD")
class OpenAIDatasetMultiTurnForFullDuplexVAD(torch.utils.data.Dataset):
"""
@ -1427,14 +1446,14 @@ class OpenAIDatasetMultiTurnForFullDuplexVAD(torch.utils.data.Dataset):
"""
def __init__(
self,
path,
index_ds: str = None,
frontend=None,
tokenizer=None,
int_pad_value: int = -1,
float_pad_value: float = 0.0,
**kwargs,
self,
path,
index_ds: str = None,
frontend=None,
tokenizer=None,
int_pad_value: int = -1,
float_pad_value: float = 0.0,
**kwargs,
):
super().__init__()
index_ds_class = tables.index_ds_classes.get(index_ds)
@ -1525,7 +1544,7 @@ class OpenAIDatasetMultiTurnForFullDuplexVAD(torch.utils.data.Dataset):
)
for i, (system_prompt, user_prompt, target_out) in enumerate(
zip(system, user, assistant)
zip(system, user, assistant)
):
if len(input_ids) > self.max_token_length:
logging.info(
@ -1535,10 +1554,8 @@ class OpenAIDatasetMultiTurnForFullDuplexVAD(torch.utils.data.Dataset):
if i == 0:
source_input = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
elif i == len(system)-1:
source_input = (
f"<|im_start|>user\n{user_prompt}"
)
elif i == len(system) - 1:
source_input = f"<|im_start|>user\n{user_prompt}"
else:
source_input = (
f"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
@ -1614,22 +1631,26 @@ class OpenAIDatasetMultiTurnForFullDuplexVAD(torch.utils.data.Dataset):
turn_taking_labels = [-100] * len(labels)
barge_in_labels = [-100] * len(labels)
last_vad = [0] * fake_token_len[-1]
pos_vad = math.ceil(fake_token_len[-1] * (true_time_span/last_time_span))
pos_vad = math.ceil(fake_token_len[-1] * (true_time_span / last_time_span))
assert pos_vad <= fake_token_len[-1]
if pos_vad > 0:
last_vad[-pos_vad:] = [1] * pos_vad
if task == "turn-taking":
turn_taking_labels[-fake_token_len[-1]:] = last_vad
turn_taking_labels[-fake_token_len[-1] :] = last_vad
elif task == "barge-in":
# print(f'barge-in: {last_vad}')
barge_in_labels[-fake_token_len[-1]:] = last_vad
barge_in_labels[-fake_token_len[-1] :] = last_vad
input_ids = torch.tensor(input_ids, dtype=torch.int64) # [: self.max_token_length]
attention_mask = torch.tensor([1] * len(input_ids), dtype=torch.int32)
labels = torch.tensor(labels, dtype=torch.int64) # [: self.max_token_length]
turn_taking_labels = torch.tensor(turn_taking_labels, dtype=torch.int64) # [: self.max_token_length]
barge_in_labels = torch.tensor(barge_in_labels, dtype=torch.int64) # [: self.max_token_length]
turn_taking_labels = torch.tensor(
turn_taking_labels, dtype=torch.int64
) # [: self.max_token_length]
barge_in_labels = torch.tensor(
barge_in_labels, dtype=torch.int64
) # [: self.max_token_length]
# fbank = speech[0, :, :]
# fbank_lens = torch.tensor(fbank_lens, dtype=torch.int32)

View File

@ -1727,6 +1727,8 @@ class LLMASR4_extract_kv(nn.Module):
)
loss = model_outputs.loss
input_mask_beg = kwargs.get("input_mask_beg")
input_mask_beg[input_mask_beg < 0] = 0
input_mask = kwargs.get("input_mask")
input_mask[input_mask < 0] = 0
@ -1737,9 +1739,10 @@ class LLMASR4_extract_kv(nn.Module):
savemat(mat_file, {"kv_cache": hidden_states[0].cpu()})
for turn_id_cum in range(input_mask.shape[0]):
beg = input_mask_beg[turn_id_cum].sum(-1)
end = input_mask[turn_id_cum].sum(-1)
uttid = f"{key}_assistant_{turn_id_cum:02d}"
line = f"{uttid} {mat_file} {end}\n"
line = f"{uttid} {mat_file} {beg} {end}\n"
self.fo.write(line)
self.fo.flush()
@ -5242,6 +5245,7 @@ class LLMASR7(nn.Module):
top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True)
return top_ids
@tables.register("model_classes", "LLMVAD")
class LLMVAD(nn.Module):
""" """
@ -5381,14 +5385,23 @@ class LLMVAD(nn.Module):
print("self.llm.config:", self.llm.config)
from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer
from copy import deepcopy
self.task_decoder_layer_config = deepcopy(self.llm.config)
self.task_decoder_layer_config.hidden_size = self.llm.config.hidden_size // 4
self.task_decoder_layer_config.intermediate_size = self.llm.config.intermediate_size // 4
self.task_decoder_layer_config.num_attention_heads = self.llm.config.num_attention_heads // 4
self.task_decoder_layer_config.num_key_value_heads = self.llm.config.num_key_value_heads // 4
self.task_decoder_layer_config.num_attention_heads = (
self.llm.config.num_attention_heads // 4
)
self.task_decoder_layer_config.num_key_value_heads = (
self.llm.config.num_key_value_heads // 4
)
print("self.task_decoder_layer_config:", self.task_decoder_layer_config)
self.down_proj = nn.Linear(self.llm.config.hidden_size, self.task_decoder_layer_config.hidden_size, bias=False).to(dtype_map[self.llm_dtype])
self.task_decoder_layer = Qwen2DecoderLayer(self.task_decoder_layer_config, self.llm.config.num_hidden_layers).to(dtype_map[self.llm_dtype])
self.down_proj = nn.Linear(
self.llm.config.hidden_size, self.task_decoder_layer_config.hidden_size, bias=False
).to(dtype_map[self.llm_dtype])
self.task_decoder_layer = Qwen2DecoderLayer(
self.task_decoder_layer_config, self.llm.config.num_hidden_layers
).to(dtype_map[self.llm_dtype])
if getattr(self.llm.config, "classifier_dropout", None) is not None:
classifier_dropout = self.llm.config.classifier_dropout
elif getattr(self.llm.config, "hidden_dropout", None) is not None:
@ -5398,9 +5411,12 @@ class LLMVAD(nn.Module):
self.dropout = nn.Dropout(classifier_dropout)
self.barge_in_num_labels = 2
self.turn_taking_num_labels = 2
self.barge_in_score = nn.Linear(self.task_decoder_layer_config.hidden_size, self.barge_in_num_labels).to(dtype_map[self.llm_dtype])
self.turn_taking_score = nn.Linear(self.task_decoder_layer_config.hidden_size, self.turn_taking_num_labels).to(dtype_map[self.llm_dtype])
self.barge_in_score = nn.Linear(
self.task_decoder_layer_config.hidden_size, self.barge_in_num_labels
).to(dtype_map[self.llm_dtype])
self.turn_taking_score = nn.Linear(
self.task_decoder_layer_config.hidden_size, self.turn_taking_num_labels
).to(dtype_map[self.llm_dtype])
def forward(
self,
@ -5503,17 +5519,24 @@ class LLMVAD(nn.Module):
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, \
_prepare_4d_causal_attention_mask_for_sdpa
from transformers.modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
)
if self.llm.config._attn_implementation == "flash_attention_2":
# 2d mask is passed through the layers
causal_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
causal_attention_mask = (
attention_mask if (attention_mask is not None and 0 in attention_mask) else None
)
elif self.llm.config._attn_implementation == "sdpa" and not output_attentions:
# output_attentions=True can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
@ -5565,11 +5588,16 @@ class LLMVAD(nn.Module):
loss = None
if barge_in_labels is not None:
barge_in_labels[barge_in_labels == -1] = -100
barge_in_loss = self.loss_fct(barge_in_logits.view(-1, self.barge_in_num_labels), barge_in_labels.view(-1))
barge_in_loss = self.loss_fct(
barge_in_logits.view(-1, self.barge_in_num_labels), barge_in_labels.view(-1)
)
loss = barge_in_loss
if turn_taking_labels is not None:
turn_taking_labels[turn_taking_labels == -1] = -100
turn_taking_loss = self.loss_fct(turn_taking_logits.view(-1, self.turn_taking_num_labels), turn_taking_labels.view(-1))
turn_taking_loss = self.loss_fct(
turn_taking_logits.view(-1, self.turn_taking_num_labels),
turn_taking_labels.view(-1),
)
loss = turn_taking_loss if loss is None else loss + turn_taking_loss
stats = {}
@ -5581,7 +5609,9 @@ class LLMVAD(nn.Module):
stats["turn_taking_loss"] = torch.clone(turn_taking_loss.detach())
with torch.no_grad():
turn_taking_preds = torch.argmax(turn_taking_logits, -1)
turn_taking_acc = compute_accuracy(turn_taking_preds, turn_taking_labels, ignore_label=-100)
turn_taking_acc = compute_accuracy(
turn_taking_preds, turn_taking_labels, ignore_label=-100
)
stats["turn_taking_acc"] = turn_taking_acc
if barge_in_labels is not None:
stats["barge_in_loss"] = torch.clone(barge_in_loss.detach())
@ -5637,7 +5667,13 @@ class LLMVAD(nn.Module):
llm_dtype = "fp16" if kwargs.get("fp16", False) else llm_dtype
llm_dtype = "bf16" if kwargs.get("bf16", False) else llm_dtype
stats = {"turn_taking_preds": [], "barge_in_preds": [], "turn_taking_labels": [], "barge_in_labels": [], 'task': task}
stats = {
"turn_taking_preds": [],
"barge_in_preds": [],
"turn_taking_labels": [],
"barge_in_labels": [],
"task": task,
}
with torch.cuda.amp.autocast(
enabled=True if llm_dtype != "fp32" else False, dtype=dtype_map[llm_dtype]
):
@ -5668,18 +5704,25 @@ class LLMVAD(nn.Module):
if position_ids is None:
device = inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, \
_prepare_4d_causal_attention_mask_for_sdpa
from transformers.modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
)
if self.llm.config._attn_implementation == "flash_attention_2":
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
attention_mask = (
attention_mask if (attention_mask is not None and 0 in attention_mask) else None
)
elif self.llm.config._attn_implementation == "sdpa" and not output_attentions:
# output_attentions=True can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
@ -5734,19 +5777,41 @@ class LLMVAD(nn.Module):
for batch_idx in range(batch_size):
fbank_begin_index = fbank_beg[batch_idx, -1].item()
fbank_end_index = fbank_begin_index + fake_token_len[batch_idx, -1].item()
turn_taking_preds_last = turn_taking_preds[batch_idx, fbank_begin_index:fbank_end_index].cpu().numpy().tolist()
turn_taking_preds_last = (
turn_taking_preds[batch_idx, fbank_begin_index:fbank_end_index]
.cpu()
.numpy()
.tolist()
)
turn_taking_preds_res.append(turn_taking_preds_last)
# print(f"turn_taking_labels: {turn_taking_labels}")
turn_taking_labels_last = turn_taking_labels[batch_idx, fbank_begin_index:fbank_end_index].cpu().numpy().tolist()
turn_taking_labels_last = (
turn_taking_labels[batch_idx, fbank_begin_index:fbank_end_index]
.cpu()
.numpy()
.tolist()
)
turn_taking_labels_res.append(turn_taking_labels_last)
# print(f"turn_taking_preds: {turn_taking_preds_last}")
barge_in_preds_last = barge_in_preds[batch_idx, fbank_begin_index:fbank_end_index].cpu().numpy().tolist()
barge_in_preds_last = (
barge_in_preds[batch_idx, fbank_begin_index:fbank_end_index]
.cpu()
.numpy()
.tolist()
)
barge_in_preds_res.append(barge_in_preds_last)
# print(f"barge_in_labels: {barge_in_labels}")
barge_in_labels_last = barge_in_labels[batch_idx, fbank_begin_index:fbank_end_index].cpu().numpy().tolist()
barge_in_labels_last = (
barge_in_labels[batch_idx, fbank_begin_index:fbank_end_index]
.cpu()
.numpy()
.tolist()
)
barge_in_labels_res.append(barge_in_labels_last)
turn_taking_acc = compute_accuracy(turn_taking_preds, turn_taking_labels, ignore_label=-100)
turn_taking_acc = compute_accuracy(
turn_taking_preds, turn_taking_labels, ignore_label=-100
)
stats["turn_taking_acc"] = turn_taking_acc.item()
barge_in_acc = compute_accuracy(barge_in_preds, barge_in_labels, ignore_label=-100)
@ -5757,7 +5822,6 @@ class LLMVAD(nn.Module):
stats["barge_in_labels"].append(barge_in_labels_res)
return turn_taking_logits, barge_in_logits, meta_data, stats
def encode(self, speech, speech_lengths):
# audio encoder
encoder_out, encoder_out_lens = self.audio_encoder(speech.permute(0, 2, 1), speech_lengths)
@ -5789,20 +5853,19 @@ class LLMVAD(nn.Module):
}
if "task" in sample:
task = sample['task']
last_total_time = data[-1]['end_time'] - data[-1]['start_time']
if task == 'turn-taking':
true_time_span = data[-1]['turn-taking-gap_time-added']
task = sample["task"]
last_total_time = data[-1]["end_time"] - data[-1]["start_time"]
if task == "turn-taking":
true_time_span = data[-1]["turn-taking-gap_time-added"]
elif task == "barge-in":
true_time_span = last_total_time - data[-1]['barge-in-0']
true_time_span = last_total_time - data[-1]["barge-in-0"]
else:
raise ValueError("task must be turn-taking or barge-in")
contents["true_time_span"] = true_time_span
contents["last_total_time"] = last_total_time
contents['task'] = sample['task']
contents["task"] = sample["task"]
return contents
def data_template(self, data):
system, user, assistant = [], [], []
for i, item in enumerate(data):
@ -5828,7 +5891,6 @@ class LLMVAD(nn.Module):
return contents
def vad_data_load_speech(self, contents: dict, tokenizer, frontend, meta_data={}, **kwargs):
system = contents["system"]
@ -5852,13 +5914,9 @@ class LLMVAD(nn.Module):
if i == 0:
source_input = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
elif i == len(system) - 1:
source_input = (
f"<|im_start|>user\n{user_prompt}"
)
source_input = f"<|im_start|>user\n{user_prompt}"
else:
source_input = (
f"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
)
source_input = f"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
splits = pattern.split(source_input)
source_ids = []
@ -5941,18 +5999,22 @@ class LLMVAD(nn.Module):
if "true_time_span" in contents:
true_time_span = contents["true_time_span"]
last_time_span = contents["last_total_time"]
pos_vad = math.ceil(fake_token_len[-1] * (true_time_span/last_time_span))
pos_vad = math.ceil(fake_token_len[-1] * (true_time_span / last_time_span))
assert pos_vad <= fake_token_len[-1]
if pos_vad > 0:
last_vad[-pos_vad:] = [1] * pos_vad
turn_taking_labels[-fake_token_len[-1]:] = last_vad
barge_in_labels[-fake_token_len[-1]:] = last_vad
turn_taking_labels[-fake_token_len[-1] :] = last_vad
barge_in_labels[-fake_token_len[-1] :] = last_vad
input_ids = torch.tensor(input_ids, dtype=torch.int64) # [: self.max_token_length]
attention_mask = torch.tensor([1] * len(input_ids), dtype=torch.int32)
labels = torch.tensor(labels, dtype=torch.int64) # [: self.max_token_length]
turn_taking_labels = torch.tensor([turn_taking_labels], dtype=torch.int64) # [: self.max_token_length]
barge_in_labels = torch.tensor([barge_in_labels], dtype=torch.int64) # [: self.max_token_length]
turn_taking_labels = torch.tensor(
[turn_taking_labels], dtype=torch.int64
) # [: self.max_token_length]
barge_in_labels = torch.tensor(
[barge_in_labels], dtype=torch.int64
) # [: self.max_token_length]
# fbank = speech[0, :, :]
# fbank_lens = torch.tensor(fbank_lens, dtype=torch.int32)
@ -6004,7 +6066,9 @@ class LLMVAD(nn.Module):
raise NotImplementedError("batch decoding is not implemented")
contents = self.vad_data_template(data_in[0])
output = self.vad_data_load_speech(contents, tokenizer, frontend, meta_data=meta_data, **kwargs)
output = self.vad_data_load_speech(
contents, tokenizer, frontend, meta_data=meta_data, **kwargs
)
batch = to_device(output, kwargs["device"])
# audio encoder

View File

@ -0,0 +1,438 @@
import asyncio
import json
import websockets
import time
import logging
import tracemalloc
import numpy as np
import argparse
import ssl
# --- CLI configuration for the websocket demo server -------------------------
parser = argparse.ArgumentParser()
parser.add_argument(
"--host", type=str, default="127.0.0.1", required=False, help="host ip, localhost, 0.0.0.0"
)
# NOTE(review): help text says "grpc" but this value is used as the websocket listen port.
parser.add_argument("--port", type=int, default=10095, required=False, help="grpc server port")
# Offline ASR model id (currently unused: the model_asr AutoModel block below is commented out).
parser.add_argument(
"--asr_model",
type=str,
default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
help="model from modelscope",
)
parser.add_argument("--asr_model_revision", type=str, default="master", help="")
# Streaming ASR model id (also unused in this script).
parser.add_argument(
"--asr_model_online",
type=str,
default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online",
help="model from modelscope",
)
parser.add_argument("--asr_model_online_revision", type=str, default="master", help="")
# Voice-activity-detection model; this one IS loaded below as model_vad.
parser.add_argument(
"--vad_model",
type=str,
default="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
help="model from modelscope",
)
parser.add_argument("--vad_model_revision", type=str, default="master", help="")
parser.add_argument("--ngpu", type=int, default=1, help="0 for cpu, 1 for gpu")
parser.add_argument("--device", type=str, default="cuda", help="cuda, cpu")
parser.add_argument("--ncpu", type=int, default=4, help="cpu cores")
# TLS material; an empty --certfile disables wss (see the server startup at the bottom).
parser.add_argument(
"--certfile",
type=str,
default="../../ssl_key/server.crt",
required=False,
help="certfile for ssl",
)
parser.add_argument(
"--keyfile",
type=str,
default="../../ssl_key/server.key",
required=False,
help="keyfile for ssl",
)
args = parser.parse_args()
# Registry of currently connected clients (used by ws_serve/ws_reset/clear_websocket).
websocket_users = set()
print("model loading")
from funasr import AutoModel
# # asr
# model_asr = AutoModel(
# model=args.asr_model,
# model_revision=args.asr_model_revision,
# ngpu=args.ngpu,
# ncpu=args.ncpu,
# device=args.device,
# disable_pbar=False,
# disable_log=True,
# )
# vad
# Streaming VAD model; async_vad() feeds it audio chunks plus the per-connection
# cache held in websocket.status_dict_vad.
model_vad = AutoModel(
model=args.vad_model,
model_revision=args.vad_model_revision,
ngpu=args.ngpu,
ncpu=args.ncpu,
device=args.device,
disable_pbar=True,
disable_log=True,
# chunk_size=60,
)
# async def async_asr(websocket, audio_in):
# if len(audio_in) > 0:
# # print(len(audio_in))
# print(type(audio_in))
# rec_result = model_asr.generate(input=audio_in, **websocket.status_dict_asr)[0]
# print("offline_asr, ", rec_result)
#
#
# if len(rec_result["text"]) > 0:
# # print("offline", rec_result)
# mode = "2pass-offline" if "2pass" in websocket.mode else websocket.mode
# message = json.dumps(
# {
# "mode": mode,
# "text": rec_result["text"],
# "wav_name": websocket.wav_name,
# "is_final": websocket.is_speaking,
# }
# )
# await websocket.send(message)
import os

# from install_model_requirements import install_requirements
#
# install_requirements()
import numpy as np
import torch
import torchaudio
from transformers import TextIteratorStreamer
from threading import Thread
import time
import traceback

from funasr import AutoModel
import re
import sys

from modelscope.hub.api import HubApi

# Log in to ModelScope only when an access key is supplied via the environment.
api = HubApi()
if "key" in os.environ:
    key = os.environ["key"]
    api.login(key)

from modelscope.hub.snapshot_download import snapshot_download

# Alternative: fetch the checkpoints from the hub instead of local paths, e.g.
# llm_dir = snapshot_download('qwen/Qwen2-7B-Instruct', cache_dir=None, revision='master')
# audio_encoder_dir = snapshot_download('iic/SenseVoice', cache_dir=None, revision='master')
# NOTE(review): hard-coded cluster paths — confirm they exist on the target host.
llm_dir = "/cpfs_speech/zhifu.gzf/init_model/qwen/Qwen2-7B-Instruct"
audio_encoder_dir = "/nfs/zhifu.gzf/init_model/SenseVoiceLargeModelscope"
device = "cuda:0"

all_file_paths = [
    "/nfs/zhifu.gzf/init_model/Speech2Text_Align_V0712_modelscope"
    # "FunAudioLLM/Speech2Text_Align_V0712",
    # "FunAudioLLM/Speech2Text_Align_V0718",
    # "FunAudioLLM/Speech2Text_Align_V0628",
]
llm_kwargs = {"num_beams": 1, "do_sample": False}
ckpt_dir = all_file_paths[0]

# Load the speech-in/text-out LLM once at startup; generation runs in bf16.
model_llm = AutoModel(
    model=ckpt_dir,
    device=device,
    fp16=False,
    bf16=False,
    llm_dtype="bf16",
    max_length=1024,
    llm_kwargs=llm_kwargs,
    llm_conf={"init_param_path": llm_dir},
    tokenizer_conf={"init_param_path": llm_dir},
    audio_encoder=audio_encoder_dir,
)

# Unpack the pieces model_inference() needs and bundle them as the default state.
model = model_llm.model
frontend = model_llm.kwargs["frontend"]
tokenizer = model_llm.kwargs["tokenizer"]
model_dict = {"model": model, "frontend": frontend, "tokenizer": tokenizer}
async def model_inference(
    websocket,
    audio_in,
    his_state=None,
    system_prompt="",
    state=None,
    turn_num=5,
    history=None,
    text_usr="",
):
    """Run one dialogue turn of the speech LLM for this client.

    Appends the new user turn (text + raw audio bytes) to the per-connection
    history stored on ``websocket.llm_state``, streams partial generations back
    as "2pass-online" messages, then sends the full reply as "2pass-offline".
    """
    if his_state is None:
        his_state = model_dict
    model = his_state["model"]
    frontend = his_state["frontend"]
    tokenizer = his_state["tokenizer"]

    # Dialogue history lives on the websocket object, one list per connection.
    if websocket.llm_state is None:
        websocket.llm_state = {"contents_i": []}
    contents_i = websocket.llm_state["contents_i"]

    user_prompt = f"{text_usr}<|startofspeech|>!!<|endofspeech|>"
    if len(system_prompt) == 0:
        system_prompt = "你是小夏,一位典型的温婉江南姑娘。你出生于杭州,声音清甜并有亲近感,会用简洁语言表达你的想法。你是用户的好朋友。你的回答将通过逼真的文字转语音技术读出。"
    if len(contents_i) < 1:
        contents_i.append({"role": "system", "content": system_prompt})
    contents_i.append({"role": "user", "content": user_prompt, "audio": audio_in})
    contents_i.append({"role": "assistant", "content": "target_out"})

    # Clip the history: keep the system turn plus the most recent pairs.
    if len(contents_i) > 2 * turn_num + 1:
        print(
            f"clip dialog pairs from: {len(contents_i)} to: {turn_num}, \ncontents_i_before_clip: {contents_i}"
        )
        contents_i = [{"role": "system", "content": system_prompt}] + contents_i[3:]
    print(f"contents_i: {contents_i}")

    inputs_embeds, contents, batch, source_ids, meta_data = model.inference_prepare(
        [contents_i], None, "test_demo", tokenizer, frontend, device=device
    )
    model_inputs = {"inputs_embeds": inputs_embeds}

    # Generate on a background thread; the streamer yields text pieces here.
    streamer = TextIteratorStreamer(tokenizer)
    generation_kwargs = dict(model_inputs, streamer=streamer, max_new_tokens=1024)
    thread = Thread(target=model.llm.generate, kwargs=generation_kwargs)
    thread.start()

    res = ""
    beg_llm = time.time()
    for new_text in streamer:
        end_llm = time.time()
        print(f"generated new text {new_text}, time: {end_llm-beg_llm:.2f}")
        if len(new_text) > 0:
            res += new_text.replace("<|im_end|>", "")
            # Keep the stored assistant turn in sync with what has streamed so far.
            contents_i[-1]["content"] = res
            websocket.llm_state["contents_i"] = contents_i
            message = json.dumps(
                {
                    "mode": "2pass-online",
                    "text": new_text,
                    "wav_name": websocket.wav_name,
                    "is_final": websocket.is_speaking,
                }
            )
            print(f"online: {message}")
            await websocket.send(message)

    # Final message carries the whole accumulated reply.
    message = json.dumps(
        {
            "mode": "2pass-offline",
            "text": res,
            "wav_name": websocket.wav_name,
            "is_final": websocket.is_speaking,
        }
    )
    print(f"offline: {message}")
    await websocket.send(message)
# NOTE(review): model/LLM state is process-global, so concurrent clients would
# interleave histories — hence the single-client warning below.
print("model loaded! only support one client at the same time now!!!!")
async def ws_reset(websocket):
    """Clear this client's streaming caches, mark them final, and close the socket."""
    print("ws reset now, total num is ", len(websocket_users))
    # Both streaming state dicts get the same reset treatment.
    for state in (websocket.status_dict_asr_online, websocket.status_dict_vad):
        state["cache"] = {}
        state["is_final"] = True
    websocket.status_dict_punc["cache"] = {}
    await websocket.close()
async def clear_websocket():
    """Reset and close every connected client, then empty the registry."""
    for ws in websocket_users:
        await ws_reset(ws)
    websocket_users.clear()
async def ws_serve(websocket, path):
    """Handle one websocket client.

    Text frames are JSON control messages (is_speaking, chunk_interval,
    wav_name, chunk_size, encoder/decoder_chunk_look_back, hotwords, mode).
    Binary frames are audio chunks: they are run through the streaming VAD and,
    at each detected speech end point, the buffered segment is sent to
    model_inference() for LLM generation.
    """
    frames = []             # rolling buffer of recent audio frames
    frames_asr = []         # frames belonging to the current speech segment
    frames_asr_online = []  # kept for protocol parity; never filled here
    global websocket_users
    # await clear_websocket()
    websocket_users.add(websocket)
    websocket.status_dict_asr = {}
    websocket.status_dict_asr_online = {"cache": {}, "is_final": False}
    websocket.status_dict_vad = {"cache": {}, "is_final": False}
    websocket.status_dict_punc = {"cache": {}}
    websocket.chunk_interval = 10
    websocket.vad_pre_idx = 0
    speech_start = False
    speech_end_i = -1
    websocket.wav_name = "microphone"
    websocket.mode = "2pass"
    websocket.llm_state = None
    # Fix: default until the client announces it; previously reading
    # websocket.is_speaking before the first control message raised AttributeError.
    websocket.is_speaking = True
    print("new user connected", flush=True)
    try:
        async for message in websocket:
            if isinstance(message, str):
                messagejson = json.loads(message)
                if "is_speaking" in messagejson:
                    websocket.is_speaking = messagejson["is_speaking"]
                    websocket.status_dict_asr_online["is_final"] = not websocket.is_speaking
                if "chunk_interval" in messagejson:
                    websocket.chunk_interval = messagejson["chunk_interval"]
                if "wav_name" in messagejson:
                    websocket.wav_name = messagejson.get("wav_name")
                if "chunk_size" in messagejson:
                    chunk_size = messagejson["chunk_size"]
                    if isinstance(chunk_size, str):
                        chunk_size = chunk_size.split(",")
                    websocket.status_dict_asr_online["chunk_size"] = [int(x) for x in chunk_size]
                if "encoder_chunk_look_back" in messagejson:
                    websocket.status_dict_asr_online["encoder_chunk_look_back"] = messagejson[
                        "encoder_chunk_look_back"
                    ]
                if "decoder_chunk_look_back" in messagejson:
                    websocket.status_dict_asr_online["decoder_chunk_look_back"] = messagejson[
                        "decoder_chunk_look_back"
                    ]
                # Fix: the guard checked "hotword" but the read used "hotwords",
                # so the branch either never fired or raised KeyError. Accept both.
                if "hotwords" in messagejson:
                    websocket.status_dict_asr["hotword"] = messagejson["hotwords"]
                elif "hotword" in messagejson:
                    websocket.status_dict_asr["hotword"] = messagejson["hotword"]
                if "mode" in messagejson:
                    websocket.mode = messagejson["mode"]
            websocket.status_dict_vad["chunk_size"] = int(
                websocket.status_dict_asr_online["chunk_size"][1] * 60 / websocket.chunk_interval
            )
            if len(frames_asr_online) > 0 or len(frames_asr) > 0 or not isinstance(message, str):
                if not isinstance(message, str):
                    frames.append(message)
                    # 16 kHz, 16-bit mono PCM: 32 bytes per millisecond.
                    duration_ms = len(message) // 32
                    websocket.vad_pre_idx += duration_ms
                    if speech_start:
                        frames_asr.append(message)
                    # vad online
                    # Fix: reset per frame so a failed VAD call can no longer leave
                    # stale (or, on the first frame, unbound) boundary values behind.
                    speech_start_i = -1
                    speech_end_i = -1
                    try:
                        speech_start_i, speech_end_i = await async_vad(websocket, message)
                    except Exception:  # fix: was a bare except
                        print("error in vad")
                    if speech_start_i != -1:
                        speech_start = True
                        # Re-seed the segment with the frames that preceded the
                        # detected start point.
                        beg_bias = (websocket.vad_pre_idx - speech_start_i) // duration_ms
                        frames_pre = frames[-beg_bias:]
                        frames_asr = []
                        frames_asr.extend(frames_pre)
                # asr punc offline
                if speech_end_i != -1 or not websocket.is_speaking:
                    # print("vad end point")
                    if websocket.mode == "2pass" or websocket.mode == "offline":
                        audio_in = b"".join(frames_asr)
                        try:
                            # await async_asr(websocket, audio_in)
                            await model_inference(websocket, audio_in)
                        except Exception as e:
                            print(f"{str(e)}, {traceback.format_exc()}")
                    frames_asr = []
                    speech_start = False
                    if not websocket.is_speaking:
                        # Utterance finished: drop all buffered audio and VAD state.
                        websocket.vad_pre_idx = 0
                        frames = []
                        websocket.status_dict_vad["cache"] = {}
                    else:
                        frames = frames[-20:]
            else:
                print(f"message: {message}")
    except websockets.ConnectionClosed:
        print("ConnectionClosed...", websocket_users, flush=True)
        await ws_reset(websocket)
        websocket_users.remove(websocket)
    except websockets.InvalidState:
        print("InvalidState...")
    except Exception as e:
        print("Exception:", e)
async def async_vad(websocket, audio_in):
    """Feed one audio chunk to the streaming VAD model.

    Returns ``(speech_start, speech_end)``; -1 means the corresponding boundary
    was not detected in this chunk. Only single-segment results are used —
    empty or multi-segment outputs yield ``(-1, -1)``.
    """
    segments = model_vad.generate(input=audio_in, **websocket.status_dict_vad)[0]["value"]
    # print(segments)
    if len(segments) != 1:
        return -1, -1
    # The model already uses -1 for "boundary not found", so pass values through.
    return segments[0][0], segments[0][1]
# Serve over WSS when a certificate is configured, otherwise plain WS.
serve_kwargs = dict(subprotocols=["binary"], ping_interval=None)
if len(args.certfile) > 0:
    # Generate with Lets Encrypt, copied to this location, chown to current user and 400 permissions
    ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    ssl_context.load_cert_chain(args.certfile, keyfile=args.keyfile)
    serve_kwargs["ssl"] = ssl_context
start_server = websockets.serve(ws_serve, args.host, args.port, **serve_kwargs)

loop = asyncio.get_event_loop()
loop.run_until_complete(start_server)
loop.run_forever()