mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
docs
This commit is contained in:
parent
2d29a079ee
commit
70bdbabcb2
@ -14,8 +14,8 @@ from threading import Thread
|
||||
import torch
|
||||
import time
|
||||
|
||||
torch.backends.cuda.enable_mem_efficient_sdp(False)
|
||||
torch.backends.cuda.enable_flash_sdp(False)
|
||||
# torch.backends.cuda.enable_mem_efficient_sdp(False)
|
||||
# torch.backends.cuda.enable_flash_sdp(False)
|
||||
|
||||
from funasr import AutoModel
|
||||
|
||||
|
||||
@ -7,6 +7,7 @@ from funasr.register import tables
|
||||
from funasr.utils.load_utils import extract_fbank, load_audio_text_image_video
|
||||
import math
|
||||
|
||||
|
||||
@tables.register("dataset_classes", "OpenAIDataset")
|
||||
class OpenAIDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
@ -776,6 +777,7 @@ class OpenAIDatasetMultiTurnCodec(torch.utils.data.Dataset):
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
@tables.register("dataset_classes", "OpenAIDatasetMultiTurnCodecMel")
|
||||
class OpenAIDatasetMultiTurnCodecMel(torch.utils.data.Dataset):
|
||||
"""
|
||||
@ -881,7 +883,8 @@ class OpenAIDatasetMultiTurnCodecMel(torch.utils.data.Dataset):
|
||||
audio,
|
||||
audio_len,
|
||||
input_mask,
|
||||
) = ([], [], [], [], [], [], [], [], [], [], [], [])
|
||||
input_mask_beg,
|
||||
) = ([], [], [], [], [], [], [], [], [], [], [], [], [])
|
||||
|
||||
multiturn_num = len(system)
|
||||
for i, (system_prompt, user_prompt, target_out) in enumerate(
|
||||
@ -998,6 +1001,19 @@ class OpenAIDatasetMultiTurnCodecMel(torch.utils.data.Dataset):
|
||||
fake_token_len += [fake_token_len_i]
|
||||
source_mask = [-100] * len(source_ids)
|
||||
|
||||
if i == 0:
|
||||
sys_prompt = f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
|
||||
sys_prompt_len = self.tokenizer.encode(sys_prompt)
|
||||
input_mask_i = (
|
||||
[1] * len(sys_prompt_len) + [0] * len(source_ids) + [0] * len(target_ids)
|
||||
)
|
||||
else:
|
||||
input_mask_i = (
|
||||
[1] * len(input_ids) + [0] * len(source_ids) + [0] * len(target_ids)
|
||||
)
|
||||
input_mask_i = torch.tensor(input_mask_i, dtype=torch.int64)
|
||||
input_mask_beg.append(input_mask_i)
|
||||
|
||||
input_mask_i = [1] * len(input_ids) + [1] * len(source_ids) + [0] * len(target_ids)
|
||||
input_mask_i = torch.tensor(input_mask_i, dtype=torch.int64)
|
||||
input_mask.append(input_mask_i)
|
||||
@ -1043,6 +1059,8 @@ class OpenAIDatasetMultiTurnCodecMel(torch.utils.data.Dataset):
|
||||
output["audio_len"] = audio_len # torch.tensor(audio_len, dtype=torch.int32)
|
||||
if len(input_mask) > 0:
|
||||
output["input_mask"] = input_mask
|
||||
output["input_mask_beg"] = input_mask_beg
|
||||
|
||||
if key is not None:
|
||||
output["key"] = key
|
||||
break
|
||||
@ -1420,6 +1438,7 @@ class OpenAIDatasetMultiTurnCodecMel2(torch.utils.data.Dataset):
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
@tables.register("dataset_classes", "OpenAIDatasetMultiTurnForFullDuplexVAD")
|
||||
class OpenAIDatasetMultiTurnForFullDuplexVAD(torch.utils.data.Dataset):
|
||||
"""
|
||||
@ -1427,14 +1446,14 @@ class OpenAIDatasetMultiTurnForFullDuplexVAD(torch.utils.data.Dataset):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
path,
|
||||
index_ds: str = None,
|
||||
frontend=None,
|
||||
tokenizer=None,
|
||||
int_pad_value: int = -1,
|
||||
float_pad_value: float = 0.0,
|
||||
**kwargs,
|
||||
self,
|
||||
path,
|
||||
index_ds: str = None,
|
||||
frontend=None,
|
||||
tokenizer=None,
|
||||
int_pad_value: int = -1,
|
||||
float_pad_value: float = 0.0,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
index_ds_class = tables.index_ds_classes.get(index_ds)
|
||||
@ -1525,7 +1544,7 @@ class OpenAIDatasetMultiTurnForFullDuplexVAD(torch.utils.data.Dataset):
|
||||
)
|
||||
|
||||
for i, (system_prompt, user_prompt, target_out) in enumerate(
|
||||
zip(system, user, assistant)
|
||||
zip(system, user, assistant)
|
||||
):
|
||||
if len(input_ids) > self.max_token_length:
|
||||
logging.info(
|
||||
@ -1535,10 +1554,8 @@ class OpenAIDatasetMultiTurnForFullDuplexVAD(torch.utils.data.Dataset):
|
||||
|
||||
if i == 0:
|
||||
source_input = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
|
||||
elif i == len(system)-1:
|
||||
source_input = (
|
||||
f"<|im_start|>user\n{user_prompt}"
|
||||
)
|
||||
elif i == len(system) - 1:
|
||||
source_input = f"<|im_start|>user\n{user_prompt}"
|
||||
else:
|
||||
source_input = (
|
||||
f"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
|
||||
@ -1614,22 +1631,26 @@ class OpenAIDatasetMultiTurnForFullDuplexVAD(torch.utils.data.Dataset):
|
||||
turn_taking_labels = [-100] * len(labels)
|
||||
barge_in_labels = [-100] * len(labels)
|
||||
last_vad = [0] * fake_token_len[-1]
|
||||
pos_vad = math.ceil(fake_token_len[-1] * (true_time_span/last_time_span))
|
||||
pos_vad = math.ceil(fake_token_len[-1] * (true_time_span / last_time_span))
|
||||
assert pos_vad <= fake_token_len[-1]
|
||||
if pos_vad > 0:
|
||||
last_vad[-pos_vad:] = [1] * pos_vad
|
||||
|
||||
if task == "turn-taking":
|
||||
turn_taking_labels[-fake_token_len[-1]:] = last_vad
|
||||
turn_taking_labels[-fake_token_len[-1] :] = last_vad
|
||||
elif task == "barge-in":
|
||||
# print(f'barge-in: {last_vad}')
|
||||
barge_in_labels[-fake_token_len[-1]:] = last_vad
|
||||
barge_in_labels[-fake_token_len[-1] :] = last_vad
|
||||
|
||||
input_ids = torch.tensor(input_ids, dtype=torch.int64) # [: self.max_token_length]
|
||||
attention_mask = torch.tensor([1] * len(input_ids), dtype=torch.int32)
|
||||
labels = torch.tensor(labels, dtype=torch.int64) # [: self.max_token_length]
|
||||
turn_taking_labels = torch.tensor(turn_taking_labels, dtype=torch.int64) # [: self.max_token_length]
|
||||
barge_in_labels = torch.tensor(barge_in_labels, dtype=torch.int64) # [: self.max_token_length]
|
||||
turn_taking_labels = torch.tensor(
|
||||
turn_taking_labels, dtype=torch.int64
|
||||
) # [: self.max_token_length]
|
||||
barge_in_labels = torch.tensor(
|
||||
barge_in_labels, dtype=torch.int64
|
||||
) # [: self.max_token_length]
|
||||
|
||||
# fbank = speech[0, :, :]
|
||||
# fbank_lens = torch.tensor(fbank_lens, dtype=torch.int32)
|
||||
|
||||
@ -1727,6 +1727,8 @@ class LLMASR4_extract_kv(nn.Module):
|
||||
)
|
||||
loss = model_outputs.loss
|
||||
|
||||
input_mask_beg = kwargs.get("input_mask_beg")
|
||||
input_mask_beg[input_mask_beg < 0] = 0
|
||||
input_mask = kwargs.get("input_mask")
|
||||
input_mask[input_mask < 0] = 0
|
||||
|
||||
@ -1737,9 +1739,10 @@ class LLMASR4_extract_kv(nn.Module):
|
||||
savemat(mat_file, {"kv_cache": hidden_states[0].cpu()})
|
||||
|
||||
for turn_id_cum in range(input_mask.shape[0]):
|
||||
beg = input_mask_beg[turn_id_cum].sum(-1)
|
||||
end = input_mask[turn_id_cum].sum(-1)
|
||||
uttid = f"{key}_assistant_{turn_id_cum:02d}"
|
||||
line = f"{uttid} {mat_file} {end}\n"
|
||||
line = f"{uttid} {mat_file} {beg} {end}\n"
|
||||
self.fo.write(line)
|
||||
self.fo.flush()
|
||||
|
||||
@ -5242,6 +5245,7 @@ class LLMASR7(nn.Module):
|
||||
top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True)
|
||||
return top_ids
|
||||
|
||||
|
||||
@tables.register("model_classes", "LLMVAD")
|
||||
class LLMVAD(nn.Module):
|
||||
""" """
|
||||
@ -5381,14 +5385,23 @@ class LLMVAD(nn.Module):
|
||||
print("self.llm.config:", self.llm.config)
|
||||
from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer
|
||||
from copy import deepcopy
|
||||
|
||||
self.task_decoder_layer_config = deepcopy(self.llm.config)
|
||||
self.task_decoder_layer_config.hidden_size = self.llm.config.hidden_size // 4
|
||||
self.task_decoder_layer_config.intermediate_size = self.llm.config.intermediate_size // 4
|
||||
self.task_decoder_layer_config.num_attention_heads = self.llm.config.num_attention_heads // 4
|
||||
self.task_decoder_layer_config.num_key_value_heads = self.llm.config.num_key_value_heads // 4
|
||||
self.task_decoder_layer_config.num_attention_heads = (
|
||||
self.llm.config.num_attention_heads // 4
|
||||
)
|
||||
self.task_decoder_layer_config.num_key_value_heads = (
|
||||
self.llm.config.num_key_value_heads // 4
|
||||
)
|
||||
print("self.task_decoder_layer_config:", self.task_decoder_layer_config)
|
||||
self.down_proj = nn.Linear(self.llm.config.hidden_size, self.task_decoder_layer_config.hidden_size, bias=False).to(dtype_map[self.llm_dtype])
|
||||
self.task_decoder_layer = Qwen2DecoderLayer(self.task_decoder_layer_config, self.llm.config.num_hidden_layers).to(dtype_map[self.llm_dtype])
|
||||
self.down_proj = nn.Linear(
|
||||
self.llm.config.hidden_size, self.task_decoder_layer_config.hidden_size, bias=False
|
||||
).to(dtype_map[self.llm_dtype])
|
||||
self.task_decoder_layer = Qwen2DecoderLayer(
|
||||
self.task_decoder_layer_config, self.llm.config.num_hidden_layers
|
||||
).to(dtype_map[self.llm_dtype])
|
||||
if getattr(self.llm.config, "classifier_dropout", None) is not None:
|
||||
classifier_dropout = self.llm.config.classifier_dropout
|
||||
elif getattr(self.llm.config, "hidden_dropout", None) is not None:
|
||||
@ -5398,9 +5411,12 @@ class LLMVAD(nn.Module):
|
||||
self.dropout = nn.Dropout(classifier_dropout)
|
||||
self.barge_in_num_labels = 2
|
||||
self.turn_taking_num_labels = 2
|
||||
self.barge_in_score = nn.Linear(self.task_decoder_layer_config.hidden_size, self.barge_in_num_labels).to(dtype_map[self.llm_dtype])
|
||||
self.turn_taking_score = nn.Linear(self.task_decoder_layer_config.hidden_size, self.turn_taking_num_labels).to(dtype_map[self.llm_dtype])
|
||||
|
||||
self.barge_in_score = nn.Linear(
|
||||
self.task_decoder_layer_config.hidden_size, self.barge_in_num_labels
|
||||
).to(dtype_map[self.llm_dtype])
|
||||
self.turn_taking_score = nn.Linear(
|
||||
self.task_decoder_layer_config.hidden_size, self.turn_taking_num_labels
|
||||
).to(dtype_map[self.llm_dtype])
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -5503,17 +5519,24 @@ class LLMVAD(nn.Module):
|
||||
if position_ids is None:
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
|
||||
past_key_values_length,
|
||||
seq_length + past_key_values_length,
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
else:
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, \
|
||||
_prepare_4d_causal_attention_mask_for_sdpa
|
||||
from transformers.modeling_attn_mask_utils import (
|
||||
_prepare_4d_causal_attention_mask,
|
||||
_prepare_4d_causal_attention_mask_for_sdpa,
|
||||
)
|
||||
|
||||
if self.llm.config._attn_implementation == "flash_attention_2":
|
||||
# 2d mask is passed through the layers
|
||||
causal_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
||||
causal_attention_mask = (
|
||||
attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
||||
)
|
||||
elif self.llm.config._attn_implementation == "sdpa" and not output_attentions:
|
||||
# output_attentions=True can not be supported when using SDPA, and we fall back on
|
||||
# the manual implementation that requires a 4D causal mask in all cases.
|
||||
@ -5565,11 +5588,16 @@ class LLMVAD(nn.Module):
|
||||
loss = None
|
||||
if barge_in_labels is not None:
|
||||
barge_in_labels[barge_in_labels == -1] = -100
|
||||
barge_in_loss = self.loss_fct(barge_in_logits.view(-1, self.barge_in_num_labels), barge_in_labels.view(-1))
|
||||
barge_in_loss = self.loss_fct(
|
||||
barge_in_logits.view(-1, self.barge_in_num_labels), barge_in_labels.view(-1)
|
||||
)
|
||||
loss = barge_in_loss
|
||||
if turn_taking_labels is not None:
|
||||
turn_taking_labels[turn_taking_labels == -1] = -100
|
||||
turn_taking_loss = self.loss_fct(turn_taking_logits.view(-1, self.turn_taking_num_labels), turn_taking_labels.view(-1))
|
||||
turn_taking_loss = self.loss_fct(
|
||||
turn_taking_logits.view(-1, self.turn_taking_num_labels),
|
||||
turn_taking_labels.view(-1),
|
||||
)
|
||||
loss = turn_taking_loss if loss is None else loss + turn_taking_loss
|
||||
|
||||
stats = {}
|
||||
@ -5581,7 +5609,9 @@ class LLMVAD(nn.Module):
|
||||
stats["turn_taking_loss"] = torch.clone(turn_taking_loss.detach())
|
||||
with torch.no_grad():
|
||||
turn_taking_preds = torch.argmax(turn_taking_logits, -1)
|
||||
turn_taking_acc = compute_accuracy(turn_taking_preds, turn_taking_labels, ignore_label=-100)
|
||||
turn_taking_acc = compute_accuracy(
|
||||
turn_taking_preds, turn_taking_labels, ignore_label=-100
|
||||
)
|
||||
stats["turn_taking_acc"] = turn_taking_acc
|
||||
if barge_in_labels is not None:
|
||||
stats["barge_in_loss"] = torch.clone(barge_in_loss.detach())
|
||||
@ -5637,7 +5667,13 @@ class LLMVAD(nn.Module):
|
||||
llm_dtype = "fp16" if kwargs.get("fp16", False) else llm_dtype
|
||||
llm_dtype = "bf16" if kwargs.get("bf16", False) else llm_dtype
|
||||
|
||||
stats = {"turn_taking_preds": [], "barge_in_preds": [], "turn_taking_labels": [], "barge_in_labels": [], 'task': task}
|
||||
stats = {
|
||||
"turn_taking_preds": [],
|
||||
"barge_in_preds": [],
|
||||
"turn_taking_labels": [],
|
||||
"barge_in_labels": [],
|
||||
"task": task,
|
||||
}
|
||||
with torch.cuda.amp.autocast(
|
||||
enabled=True if llm_dtype != "fp32" else False, dtype=dtype_map[llm_dtype]
|
||||
):
|
||||
@ -5668,18 +5704,25 @@ class LLMVAD(nn.Module):
|
||||
if position_ids is None:
|
||||
device = inputs_embeds.device
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
|
||||
past_key_values_length,
|
||||
seq_length + past_key_values_length,
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
else:
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
|
||||
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, \
|
||||
_prepare_4d_causal_attention_mask_for_sdpa
|
||||
from transformers.modeling_attn_mask_utils import (
|
||||
_prepare_4d_causal_attention_mask,
|
||||
_prepare_4d_causal_attention_mask_for_sdpa,
|
||||
)
|
||||
|
||||
if self.llm.config._attn_implementation == "flash_attention_2":
|
||||
# 2d mask is passed through the layers
|
||||
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
||||
attention_mask = (
|
||||
attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
||||
)
|
||||
elif self.llm.config._attn_implementation == "sdpa" and not output_attentions:
|
||||
# output_attentions=True can not be supported when using SDPA, and we fall back on
|
||||
# the manual implementation that requires a 4D causal mask in all cases.
|
||||
@ -5734,19 +5777,41 @@ class LLMVAD(nn.Module):
|
||||
for batch_idx in range(batch_size):
|
||||
fbank_begin_index = fbank_beg[batch_idx, -1].item()
|
||||
fbank_end_index = fbank_begin_index + fake_token_len[batch_idx, -1].item()
|
||||
turn_taking_preds_last = turn_taking_preds[batch_idx, fbank_begin_index:fbank_end_index].cpu().numpy().tolist()
|
||||
turn_taking_preds_last = (
|
||||
turn_taking_preds[batch_idx, fbank_begin_index:fbank_end_index]
|
||||
.cpu()
|
||||
.numpy()
|
||||
.tolist()
|
||||
)
|
||||
turn_taking_preds_res.append(turn_taking_preds_last)
|
||||
# print(f"turn_taking_labels: {turn_taking_labels}")
|
||||
turn_taking_labels_last = turn_taking_labels[batch_idx, fbank_begin_index:fbank_end_index].cpu().numpy().tolist()
|
||||
turn_taking_labels_last = (
|
||||
turn_taking_labels[batch_idx, fbank_begin_index:fbank_end_index]
|
||||
.cpu()
|
||||
.numpy()
|
||||
.tolist()
|
||||
)
|
||||
turn_taking_labels_res.append(turn_taking_labels_last)
|
||||
# print(f"turn_taking_preds: {turn_taking_preds_last}")
|
||||
barge_in_preds_last = barge_in_preds[batch_idx, fbank_begin_index:fbank_end_index].cpu().numpy().tolist()
|
||||
barge_in_preds_last = (
|
||||
barge_in_preds[batch_idx, fbank_begin_index:fbank_end_index]
|
||||
.cpu()
|
||||
.numpy()
|
||||
.tolist()
|
||||
)
|
||||
barge_in_preds_res.append(barge_in_preds_last)
|
||||
# print(f"barge_in_labels: {barge_in_labels}")
|
||||
barge_in_labels_last = barge_in_labels[batch_idx, fbank_begin_index:fbank_end_index].cpu().numpy().tolist()
|
||||
barge_in_labels_last = (
|
||||
barge_in_labels[batch_idx, fbank_begin_index:fbank_end_index]
|
||||
.cpu()
|
||||
.numpy()
|
||||
.tolist()
|
||||
)
|
||||
barge_in_labels_res.append(barge_in_labels_last)
|
||||
|
||||
turn_taking_acc = compute_accuracy(turn_taking_preds, turn_taking_labels, ignore_label=-100)
|
||||
turn_taking_acc = compute_accuracy(
|
||||
turn_taking_preds, turn_taking_labels, ignore_label=-100
|
||||
)
|
||||
stats["turn_taking_acc"] = turn_taking_acc.item()
|
||||
|
||||
barge_in_acc = compute_accuracy(barge_in_preds, barge_in_labels, ignore_label=-100)
|
||||
@ -5757,7 +5822,6 @@ class LLMVAD(nn.Module):
|
||||
stats["barge_in_labels"].append(barge_in_labels_res)
|
||||
return turn_taking_logits, barge_in_logits, meta_data, stats
|
||||
|
||||
|
||||
def encode(self, speech, speech_lengths):
|
||||
# audio encoder
|
||||
encoder_out, encoder_out_lens = self.audio_encoder(speech.permute(0, 2, 1), speech_lengths)
|
||||
@ -5789,20 +5853,19 @@ class LLMVAD(nn.Module):
|
||||
}
|
||||
|
||||
if "task" in sample:
|
||||
task = sample['task']
|
||||
last_total_time = data[-1]['end_time'] - data[-1]['start_time']
|
||||
if task == 'turn-taking':
|
||||
true_time_span = data[-1]['turn-taking-gap_time-added']
|
||||
task = sample["task"]
|
||||
last_total_time = data[-1]["end_time"] - data[-1]["start_time"]
|
||||
if task == "turn-taking":
|
||||
true_time_span = data[-1]["turn-taking-gap_time-added"]
|
||||
elif task == "barge-in":
|
||||
true_time_span = last_total_time - data[-1]['barge-in-0']
|
||||
true_time_span = last_total_time - data[-1]["barge-in-0"]
|
||||
else:
|
||||
raise ValueError("task must be turn-taking or barge-in")
|
||||
contents["true_time_span"] = true_time_span
|
||||
contents["last_total_time"] = last_total_time
|
||||
contents['task'] = sample['task']
|
||||
contents["task"] = sample["task"]
|
||||
return contents
|
||||
|
||||
|
||||
def data_template(self, data):
|
||||
system, user, assistant = [], [], []
|
||||
for i, item in enumerate(data):
|
||||
@ -5828,7 +5891,6 @@ class LLMVAD(nn.Module):
|
||||
|
||||
return contents
|
||||
|
||||
|
||||
def vad_data_load_speech(self, contents: dict, tokenizer, frontend, meta_data={}, **kwargs):
|
||||
|
||||
system = contents["system"]
|
||||
@ -5852,13 +5914,9 @@ class LLMVAD(nn.Module):
|
||||
if i == 0:
|
||||
source_input = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
|
||||
elif i == len(system) - 1:
|
||||
source_input = (
|
||||
f"<|im_start|>user\n{user_prompt}"
|
||||
)
|
||||
source_input = f"<|im_start|>user\n{user_prompt}"
|
||||
else:
|
||||
source_input = (
|
||||
f"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
|
||||
)
|
||||
source_input = f"<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n"
|
||||
|
||||
splits = pattern.split(source_input)
|
||||
source_ids = []
|
||||
@ -5941,18 +5999,22 @@ class LLMVAD(nn.Module):
|
||||
if "true_time_span" in contents:
|
||||
true_time_span = contents["true_time_span"]
|
||||
last_time_span = contents["last_total_time"]
|
||||
pos_vad = math.ceil(fake_token_len[-1] * (true_time_span/last_time_span))
|
||||
pos_vad = math.ceil(fake_token_len[-1] * (true_time_span / last_time_span))
|
||||
assert pos_vad <= fake_token_len[-1]
|
||||
if pos_vad > 0:
|
||||
last_vad[-pos_vad:] = [1] * pos_vad
|
||||
turn_taking_labels[-fake_token_len[-1]:] = last_vad
|
||||
barge_in_labels[-fake_token_len[-1]:] = last_vad
|
||||
turn_taking_labels[-fake_token_len[-1] :] = last_vad
|
||||
barge_in_labels[-fake_token_len[-1] :] = last_vad
|
||||
|
||||
input_ids = torch.tensor(input_ids, dtype=torch.int64) # [: self.max_token_length]
|
||||
attention_mask = torch.tensor([1] * len(input_ids), dtype=torch.int32)
|
||||
labels = torch.tensor(labels, dtype=torch.int64) # [: self.max_token_length]
|
||||
turn_taking_labels = torch.tensor([turn_taking_labels], dtype=torch.int64) # [: self.max_token_length]
|
||||
barge_in_labels = torch.tensor([barge_in_labels], dtype=torch.int64) # [: self.max_token_length]
|
||||
turn_taking_labels = torch.tensor(
|
||||
[turn_taking_labels], dtype=torch.int64
|
||||
) # [: self.max_token_length]
|
||||
barge_in_labels = torch.tensor(
|
||||
[barge_in_labels], dtype=torch.int64
|
||||
) # [: self.max_token_length]
|
||||
|
||||
# fbank = speech[0, :, :]
|
||||
# fbank_lens = torch.tensor(fbank_lens, dtype=torch.int32)
|
||||
@ -6004,7 +6066,9 @@ class LLMVAD(nn.Module):
|
||||
raise NotImplementedError("batch decoding is not implemented")
|
||||
|
||||
contents = self.vad_data_template(data_in[0])
|
||||
output = self.vad_data_load_speech(contents, tokenizer, frontend, meta_data=meta_data, **kwargs)
|
||||
output = self.vad_data_load_speech(
|
||||
contents, tokenizer, frontend, meta_data=meta_data, **kwargs
|
||||
)
|
||||
batch = to_device(output, kwargs["device"])
|
||||
|
||||
# audio encoder
|
||||
|
||||
438
runtime/python/websocket/funasr_wss_server_llm.py
Normal file
438
runtime/python/websocket/funasr_wss_server_llm.py
Normal file
@ -0,0 +1,438 @@
|
||||
import asyncio
|
||||
import json
|
||||
import websockets
|
||||
import time
|
||||
import logging
|
||||
import tracemalloc
|
||||
import numpy as np
|
||||
import argparse
|
||||
import ssl
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--host", type=str, default="127.0.0.1", required=False, help="host ip, localhost, 0.0.0.0"
|
||||
)
|
||||
parser.add_argument("--port", type=int, default=10095, required=False, help="grpc server port")
|
||||
parser.add_argument(
|
||||
"--asr_model",
|
||||
type=str,
|
||||
default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
|
||||
help="model from modelscope",
|
||||
)
|
||||
parser.add_argument("--asr_model_revision", type=str, default="master", help="")
|
||||
parser.add_argument(
|
||||
"--asr_model_online",
|
||||
type=str,
|
||||
default="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online",
|
||||
help="model from modelscope",
|
||||
)
|
||||
parser.add_argument("--asr_model_online_revision", type=str, default="master", help="")
|
||||
parser.add_argument(
|
||||
"--vad_model",
|
||||
type=str,
|
||||
default="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
|
||||
help="model from modelscope",
|
||||
)
|
||||
parser.add_argument("--vad_model_revision", type=str, default="master", help="")
|
||||
parser.add_argument("--ngpu", type=int, default=1, help="0 for cpu, 1 for gpu")
|
||||
parser.add_argument("--device", type=str, default="cuda", help="cuda, cpu")
|
||||
parser.add_argument("--ncpu", type=int, default=4, help="cpu cores")
|
||||
parser.add_argument(
|
||||
"--certfile",
|
||||
type=str,
|
||||
default="../../ssl_key/server.crt",
|
||||
required=False,
|
||||
help="certfile for ssl",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--keyfile",
|
||||
type=str,
|
||||
default="../../ssl_key/server.key",
|
||||
required=False,
|
||||
help="keyfile for ssl",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
websocket_users = set()
|
||||
|
||||
print("model loading")
|
||||
from funasr import AutoModel
|
||||
|
||||
# # asr
|
||||
# model_asr = AutoModel(
|
||||
# model=args.asr_model,
|
||||
# model_revision=args.asr_model_revision,
|
||||
# ngpu=args.ngpu,
|
||||
# ncpu=args.ncpu,
|
||||
# device=args.device,
|
||||
# disable_pbar=False,
|
||||
# disable_log=True,
|
||||
# )
|
||||
|
||||
# vad
|
||||
model_vad = AutoModel(
|
||||
model=args.vad_model,
|
||||
model_revision=args.vad_model_revision,
|
||||
ngpu=args.ngpu,
|
||||
ncpu=args.ncpu,
|
||||
device=args.device,
|
||||
disable_pbar=True,
|
||||
disable_log=True,
|
||||
# chunk_size=60,
|
||||
)
|
||||
|
||||
|
||||
# async def async_asr(websocket, audio_in):
|
||||
# if len(audio_in) > 0:
|
||||
# # print(len(audio_in))
|
||||
# print(type(audio_in))
|
||||
# rec_result = model_asr.generate(input=audio_in, **websocket.status_dict_asr)[0]
|
||||
# print("offline_asr, ", rec_result)
|
||||
#
|
||||
#
|
||||
# if len(rec_result["text"]) > 0:
|
||||
# # print("offline", rec_result)
|
||||
# mode = "2pass-offline" if "2pass" in websocket.mode else websocket.mode
|
||||
# message = json.dumps(
|
||||
# {
|
||||
# "mode": mode,
|
||||
# "text": rec_result["text"],
|
||||
# "wav_name": websocket.wav_name,
|
||||
# "is_final": websocket.is_speaking,
|
||||
# }
|
||||
# )
|
||||
# await websocket.send(message)
|
||||
|
||||
import os
|
||||
|
||||
# from install_model_requirements import install_requirements
|
||||
#
|
||||
# install_requirements()
|
||||
|
||||
# import librosa
|
||||
# import base64
|
||||
# import io
|
||||
# import gradio as gr
|
||||
# import re
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchaudio
|
||||
from transformers import TextIteratorStreamer
|
||||
from threading import Thread
|
||||
import torch
|
||||
import time
|
||||
import traceback
|
||||
|
||||
# torch.backends.cuda.enable_mem_efficient_sdp(False)
|
||||
# torch.backends.cuda.enable_flash_sdp(False)
|
||||
|
||||
from funasr import AutoModel
|
||||
|
||||
import re
|
||||
|
||||
import sys
|
||||
|
||||
from modelscope.hub.api import HubApi
|
||||
|
||||
api = HubApi()
|
||||
if "key" in os.environ:
|
||||
key = os.environ["key"]
|
||||
api.login(key)
|
||||
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
|
||||
# os.environ["MODELSCOPE_CACHE"] = "/nfs/zhifu.gzf/modelscope"
|
||||
# llm_dir = snapshot_download('qwen/Qwen2-7B-Instruct', cache_dir=None, revision='master')
|
||||
# audio_encoder_dir = snapshot_download('iic/SenseVoice', cache_dir=None, revision='master')
|
||||
|
||||
llm_dir = "/cpfs_speech/zhifu.gzf/init_model/qwen/Qwen2-7B-Instruct"
|
||||
audio_encoder_dir = "/nfs/zhifu.gzf/init_model/SenseVoiceLargeModelscope"
|
||||
|
||||
device = "cuda:0"
|
||||
|
||||
all_file_paths = [
|
||||
"/nfs/zhifu.gzf/init_model/Speech2Text_Align_V0712_modelscope"
|
||||
# "FunAudioLLM/Speech2Text_Align_V0712",
|
||||
# "FunAudioLLM/Speech2Text_Align_V0718",
|
||||
# "FunAudioLLM/Speech2Text_Align_V0628",
|
||||
]
|
||||
|
||||
llm_kwargs = {"num_beams": 1, "do_sample": False}
|
||||
|
||||
ckpt_dir = all_file_paths[0]
|
||||
|
||||
model_llm = AutoModel(
|
||||
model=ckpt_dir,
|
||||
device=device,
|
||||
fp16=False,
|
||||
bf16=False,
|
||||
llm_dtype="bf16",
|
||||
max_length=1024,
|
||||
llm_kwargs=llm_kwargs,
|
||||
llm_conf={"init_param_path": llm_dir},
|
||||
tokenizer_conf={"init_param_path": llm_dir},
|
||||
audio_encoder=audio_encoder_dir,
|
||||
)
|
||||
|
||||
model = model_llm.model
|
||||
frontend = model_llm.kwargs["frontend"]
|
||||
tokenizer = model_llm.kwargs["tokenizer"]
|
||||
|
||||
model_dict = {"model": model, "frontend": frontend, "tokenizer": tokenizer}
|
||||
|
||||
|
||||
async def model_inference(
|
||||
websocket,
|
||||
audio_in,
|
||||
his_state=None,
|
||||
system_prompt="",
|
||||
state=None,
|
||||
turn_num=5,
|
||||
history=None,
|
||||
text_usr="",
|
||||
):
|
||||
if his_state is None:
|
||||
his_state = model_dict
|
||||
model = his_state["model"]
|
||||
frontend = his_state["frontend"]
|
||||
tokenizer = his_state["tokenizer"]
|
||||
# print(f"text_inputs: {text_inputs}")
|
||||
# print(f"audio_in: {audio_in}")
|
||||
# print(f"websocket.llm_state: {websocket.llm_state}")
|
||||
|
||||
if websocket.llm_state is None:
|
||||
websocket.llm_state = {"contents_i": []}
|
||||
# print(f"history: {history}")
|
||||
# if history is None:
|
||||
# history = []
|
||||
|
||||
# audio_in = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/tmp/1.wav"
|
||||
# user_prompt = f"<|startofspeech|>!{audio_in}<|endofspeech|>"
|
||||
user_prompt = f"{text_usr}<|startofspeech|>!!<|endofspeech|>"
|
||||
|
||||
contents_i = websocket.llm_state["contents_i"]
|
||||
# print(f"contents_i_0: {contents_i}")
|
||||
if len(system_prompt) == 0:
|
||||
system_prompt = "你是小夏,一位典型的温婉江南姑娘。你出生于杭州,声音清甜并有亲近感,会用简洁语言表达你的想法。你是用户的好朋友。你的回答将通过逼真的文字转语音技术读出。"
|
||||
|
||||
if len(contents_i) < 1:
|
||||
contents_i.append({"role": "system", "content": system_prompt})
|
||||
contents_i.append({"role": "user", "content": user_prompt, "audio": audio_in})
|
||||
contents_i.append({"role": "assistant", "content": "target_out"})
|
||||
if len(contents_i) > 2 * turn_num + 1:
|
||||
print(
|
||||
f"clip dialog pairs from: {len(contents_i)} to: {turn_num}, \ncontents_i_before_clip: {contents_i}"
|
||||
)
|
||||
contents_i = [{"role": "system", "content": system_prompt}] + contents_i[3:]
|
||||
|
||||
print(f"contents_i: {contents_i}")
|
||||
|
||||
inputs_embeds, contents, batch, source_ids, meta_data = model.inference_prepare(
|
||||
[contents_i], None, "test_demo", tokenizer, frontend, device=device
|
||||
)
|
||||
model_inputs = {}
|
||||
model_inputs["inputs_embeds"] = inputs_embeds
|
||||
|
||||
streamer = TextIteratorStreamer(tokenizer)
|
||||
|
||||
generation_kwargs = dict(model_inputs, streamer=streamer, max_new_tokens=1024)
|
||||
thread = Thread(target=model.llm.generate, kwargs=generation_kwargs)
|
||||
thread.start()
|
||||
res = ""
|
||||
beg_llm = time.time()
|
||||
for new_text in streamer:
|
||||
end_llm = time.time()
|
||||
print(f"generated new text: {new_text}, time: {end_llm-beg_llm:.2f}")
|
||||
|
||||
if len(new_text) > 0:
|
||||
res += new_text.replace("<|im_end|>", "")
|
||||
contents_i[-1]["content"] = res
|
||||
websocket.llm_state["contents_i"] = contents_i
|
||||
# history[-1][1] = res
|
||||
|
||||
mode = "2pass-online"
|
||||
message = json.dumps(
|
||||
{
|
||||
"mode": mode,
|
||||
"text": new_text,
|
||||
"wav_name": websocket.wav_name,
|
||||
"is_final": websocket.is_speaking,
|
||||
}
|
||||
)
|
||||
print(f"online: {message}")
|
||||
await websocket.send(message)
|
||||
|
||||
mode = "2pass-offline"
|
||||
message = json.dumps(
|
||||
{
|
||||
"mode": mode,
|
||||
"text": res,
|
||||
"wav_name": websocket.wav_name,
|
||||
"is_final": websocket.is_speaking,
|
||||
}
|
||||
)
|
||||
print(f"offline: {message}")
|
||||
await websocket.send(message)
|
||||
|
||||
|
||||
print("model loaded! only support one client at the same time now!!!!")
|
||||
|
||||
|
||||
async def ws_reset(websocket):
|
||||
print("ws reset now, total num is ", len(websocket_users))
|
||||
|
||||
websocket.status_dict_asr_online["cache"] = {}
|
||||
websocket.status_dict_asr_online["is_final"] = True
|
||||
websocket.status_dict_vad["cache"] = {}
|
||||
websocket.status_dict_vad["is_final"] = True
|
||||
websocket.status_dict_punc["cache"] = {}
|
||||
|
||||
await websocket.close()
|
||||
|
||||
|
||||
async def clear_websocket():
|
||||
for websocket in websocket_users:
|
||||
await ws_reset(websocket)
|
||||
websocket_users.clear()
|
||||
|
||||
|
||||
async def ws_serve(websocket, path):
    """Handle one websocket client: stream audio in, emit ASR results out.

    Text frames are JSON control messages (mode, chunk sizes, hotwords,
    speaking state); binary frames are raw audio chunks. Audio is fed to the
    streaming VAD; once a speech start is seen, chunks are accumulated into
    the current segment, and at each VAD end point (or when the client stops
    speaking) the segment is sent through ``model_inference``.

    Args:
        websocket: The connected client. Per-connection state (model caches,
            mode, wav name, ...) is attached to it as attributes.
        path: Request path required by the ``websockets`` handler signature;
            unused here.
    """
    frames = []             # rolling buffer of recent audio chunks (for pre-roll)
    frames_asr = []         # chunks belonging to the current speech segment
    frames_asr_online = []  # reserved for the online-ASR path
    global websocket_users
    # await clear_websocket()
    websocket_users.add(websocket)
    websocket.status_dict_asr = {}
    websocket.status_dict_asr_online = {"cache": {}, "is_final": False}
    websocket.status_dict_vad = {"cache": {}, "is_final": False}
    websocket.status_dict_punc = {"cache": {}}
    websocket.chunk_interval = 10
    websocket.vad_pre_idx = 0
    speech_start = False
    speech_end_i = -1
    websocket.wav_name = "microphone"
    websocket.mode = "2pass"
    websocket.llm_state = None
    print("new user connected", flush=True)

    try:
        async for message in websocket:
            if isinstance(message, str):
                messagejson = json.loads(message)

                if "is_speaking" in messagejson:
                    websocket.is_speaking = messagejson["is_speaking"]
                    websocket.status_dict_asr_online["is_final"] = not websocket.is_speaking
                if "chunk_interval" in messagejson:
                    websocket.chunk_interval = messagejson["chunk_interval"]
                if "wav_name" in messagejson:
                    websocket.wav_name = messagejson.get("wav_name")
                if "chunk_size" in messagejson:
                    chunk_size = messagejson["chunk_size"]
                    if isinstance(chunk_size, str):
                        chunk_size = chunk_size.split(",")
                    websocket.status_dict_asr_online["chunk_size"] = [int(x) for x in chunk_size]
                if "encoder_chunk_look_back" in messagejson:
                    websocket.status_dict_asr_online["encoder_chunk_look_back"] = messagejson[
                        "encoder_chunk_look_back"
                    ]
                if "decoder_chunk_look_back" in messagejson:
                    websocket.status_dict_asr_online["decoder_chunk_look_back"] = messagejson[
                        "decoder_chunk_look_back"
                    ]
                # Fix: the original tested for "hotword" but then read
                # messagejson["hotwords"], raising KeyError. Accept both
                # spellings, preferring the plural used by FunASR clients.
                if "hotwords" in messagejson:
                    websocket.status_dict_asr["hotword"] = messagejson["hotwords"]
                elif "hotword" in messagejson:
                    websocket.status_dict_asr["hotword"] = messagejson["hotword"]
                if "mode" in messagejson:
                    websocket.mode = messagejson["mode"]

            websocket.status_dict_vad["chunk_size"] = int(
                websocket.status_dict_asr_online["chunk_size"][1] * 60 / websocket.chunk_interval
            )
            if len(frames_asr_online) > 0 or len(frames_asr) > 0 or not isinstance(message, str):
                if not isinstance(message, str):
                    frames.append(message)
                    # 16 kHz, 16-bit mono PCM -> 32 bytes per millisecond
                    # (assumed from the //32 in the original — TODO confirm).
                    duration_ms = len(message) // 32
                    websocket.vad_pre_idx += duration_ms

                    if speech_start:
                        frames_asr.append(message)
                    # vad online
                    # Fix: initialize before the try so a VAD failure on the
                    # first chunk cannot raise NameError below.
                    speech_start_i = -1
                    try:
                        speech_start_i, speech_end_i = await async_vad(websocket, message)
                    except Exception:  # fix: no bare except (keeps CancelledError alive)
                        print("error in vad")
                    if speech_start_i != -1:
                        speech_start = True
                        # Re-attach the pre-roll chunks recorded before the
                        # detected speech start.
                        beg_bias = (websocket.vad_pre_idx - speech_start_i) // duration_ms
                        frames_pre = frames[-beg_bias:]
                        frames_asr = []
                        frames_asr.extend(frames_pre)
                # asr punc offline
                if speech_end_i != -1 or not websocket.is_speaking:
                    # print("vad end point")
                    if websocket.mode == "2pass" or websocket.mode == "offline":
                        audio_in = b"".join(frames_asr)
                        try:
                            # await async_asr(websocket, audio_in)
                            await model_inference(websocket, audio_in)
                        except Exception as e:
                            print(f"{str(e)}, {traceback.format_exc()}")
                    frames_asr = []
                    speech_start = False

                    if not websocket.is_speaking:
                        # Utterance fully finished: drop all buffered audio.
                        websocket.vad_pre_idx = 0
                        frames = []
                        websocket.status_dict_vad["cache"] = {}
                    else:
                        # Keep a short tail as pre-roll for the next segment.
                        frames = frames[-20:]
            else:
                print(f"message: {message}")
    except websockets.ConnectionClosed:
        print("ConnectionClosed...", websocket_users, flush=True)
        await ws_reset(websocket)
        websocket_users.remove(websocket)
    except websockets.InvalidState:
        print("InvalidState...")
    except Exception as e:
        print("Exception:", e)
async def async_vad(websocket, audio_in):
    """Run one streaming-VAD step and report segment boundaries.

    Returns:
        A ``(speech_start, speech_end)`` pair as reported by the VAD model,
        with ``-1`` standing in for any boundary not detected in this chunk.
    """
    segments_result = model_vad.generate(input=audio_in, **websocket.status_dict_vad)[0]["value"]
    # print(segments_result)

    speech_start, speech_end = -1, -1

    # Anything other than exactly one segment means no usable boundary yet.
    if len(segments_result) != 1:
        return speech_start, speech_end

    beg, end = segments_result[0][0], segments_result[0][1]
    if beg != -1:
        speech_start = beg
    if end != -1:
        speech_end = end
    return speech_start, speech_end
# Start the websocket server; wrap it in TLS when a certificate was supplied.
ssl_context = None
if len(args.certfile) > 0:
    # Generate with Lets Encrypt, copied to this location, chown to current
    # user and 400 permissions.
    ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    ssl_context.load_cert_chain(args.certfile, keyfile=args.keyfile)

server_kwargs = dict(subprotocols=["binary"], ping_interval=None)
if ssl_context is not None:
    server_kwargs["ssl"] = ssl_context
start_server = websockets.serve(ws_serve, args.host, args.port, **server_kwargs)

event_loop = asyncio.get_event_loop()
event_loop.run_until_complete(start_server)
event_loop.run_forever()
Loading…
Reference in New Issue
Block a user