mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
Merge branch 'dev_gzf_deepspeed' of gitlab.alibaba-inc.com:zhifu.gzf/FunASR into dev_gzf_deepspeed
merge
This commit is contained in:
commit
e9fb52d788
@ -2871,9 +2871,6 @@ class LLMASRXvecSlotTTS(nn.Module):
|
||||
|
||||
# tts related inference, require the kv cache of llm last layer for only the current inputs
|
||||
# TODO: select kv cache of the current turn inputs
|
||||
import pdb
|
||||
|
||||
pdb.set_trace()
|
||||
attention_mask = batch.get("attention_mask", None)
|
||||
model_outputs = self.llm(
|
||||
inputs_embeds=inputs_embeds,
|
||||
@ -2918,8 +2915,11 @@ class LLMASRXvecSlotTTS(nn.Module):
|
||||
):
|
||||
assert llm_cur_kv_cache is not None
|
||||
set_all_random_seed(rand_seed)
|
||||
speech_tokens, mel, wav = self.generate_speech(
|
||||
response, llm_cur_kv_cache, llm_cur_kv_cache_len, dtype_map[tts_dtype]
|
||||
# speech_tokens, mel, wav = self.generate_speech(
|
||||
# response, llm_cur_kv_cache, llm_cur_kv_cache_len, dtype_map[tts_dtype]
|
||||
# )
|
||||
speech_tokens, mel, wav = self.simulate_streaming_generate_speech(
|
||||
target_ids, llm_cur_kv_cache, llm_cur_kv_cache_len, dtype_map[tts_dtype], tokenizer
|
||||
)
|
||||
self.write_mel_wav(kwargs.get("output_dir"), mel, wav, key[0])
|
||||
|
||||
@ -2942,12 +2942,142 @@ class LLMASRXvecSlotTTS(nn.Module):
|
||||
None,
|
||||
outside_prompt=llm_cur_kv_cache,
|
||||
outside_prompt_lengths=llm_cur_kv_cache_len,
|
||||
sampling="threshold_1e-6",
|
||||
)
|
||||
# vocoder forward
|
||||
wav = self.vocoder.inference(mel_feats.transpose(1, 2))
|
||||
|
||||
return speech_tokens, mel_feats, wav
|
||||
|
||||
def split_characters_and_words(self, input_string):
    """Split a string into CJK characters, word runs and punctuation marks.

    Each CJK unified ideograph (U+4E00-U+9FFF) becomes its own item, every
    other maximal run of word characters becomes one item, and each
    remaining non-whitespace character becomes its own item. Whitespace is
    dropped.

    Args:
        input_string: Text to split.

    Returns:
        List of the matched substrings, in order of appearance.
    """
    # `\w` also matches CJK ideographs, so the word alternative has to
    # exclude them explicitly. Otherwise "hello你好" is returned as a single
    # item while "你好hello" is split per character.
    pattern = r'[\u4e00-\u9fff]|[^\W\u4e00-\u9fff]+|[^\w\s]'
    # re.findall returns every non-overlapping match, left to right.
    return re.findall(pattern, input_string)
|
||||
|
||||
def tts_tokenizer_warpper(self, text):
    """Tokenize ``text`` with the TTS text tokenizer and undo its added period.

    If the last token is 1542 although ``text`` does not end with "。", that
    token is dropped. Empty text or an empty token sequence is returned
    unchanged instead of raising ``IndexError``.

    Args:
        text: Text to tokenize.

    Returns:
        The token sequence produced by ``self.tts_text_tokenizer.text2tokens``,
        possibly without its last element.
    """
    text_token = self.tts_text_tokenizer.text2tokens(text)
    # Remove the punctuation appended by ttsfrd.
    # NOTE(review): 1542 is presumably the token id of "。" - confirm against
    # the tokenizer vocabulary.
    added_by_frontend = (
        len(text_token) > 0
        and text_token[-1] == 1542
        and not text.endswith("。")
    )
    if added_by_frontend:
        text_token = text_token[:-1]
    return text_token
|
||||
|
||||
def generate_speech_one_step(
    self,
    text: str, last_t_size,
    llm_cur_kv_cache, llm_cur_kv_cache_len,
    prompt_token, prompt_audio, tts_text_chunk_size,
    chunk_idx, is_last, para_len=30,
):
    """Run one incremental TTS step over the text accumulated so far.

    Runs of consecutive punctuation marks are collapsed to their first
    character, the text is wrapped as ``<|endofprompt|><|sil|>{text}`` (with
    a trailing ``<|sil|>`` on the last call) and tokenized. Synthesis only
    runs when at least ``tts_text_chunk_size`` new text tokens have arrived
    since the last synthesized step, or when ``is_last`` is true.

    The ``prompt_token`` / ``prompt_audio`` lists passed in are never
    modified; updated copies are returned in the state tuple.

    Args:
        text: All response text received so far.
        last_t_size: Number of text tokens at the last synthesized step.
        llm_cur_kv_cache: Tensor forwarded to the TTS model as
            ``outside_prompt``; its device is used for new tensors.
        llm_cur_kv_cache_len: Forwarded as ``outside_prompt_lengths``.
        prompt_token: ``[tokens, lengths]`` generated so far, or
            ``[None, None]`` before the first package.
        prompt_audio: ``[mel, lengths]`` generated so far, or
            ``[None, None]`` before the first package.
        tts_text_chunk_size: Minimum number of new text tokens needed to
            trigger synthesis before the last call.
        chunk_idx: Index of the next speech chunk, forwarded to the TTS
            model.
        is_last: Whether this is the final call for the utterance.
        para_len: Currently unused (paragraph restarting is disabled).

    Returns:
        ``((cur_token, feat, wav), (text, last_t_size, prompt_token,
        prompt_audio, chunk_idx))``. The first tuple is all ``None`` when
        nothing was synthesized in this step; the second carries the state
        to pass to the next call.
    """
    device = llm_cur_kv_cache.device
    # NOTE(review): the original list repeated '?', '!' and ';' - possibly
    # full-width variants that were lost to text normalization; confirm.
    pounc = frozenset(['。', '?', '!', ';', ':', '.', '\n'])

    # Remove duplicated punctuation: drop a mark whose predecessor in the
    # incoming text is also a mark.
    text = "".join(
        c for i, c in enumerate(text)
        if not (i > 0 and text[i - 1] in pounc and c in pounc)
    )

    cur_token, feat, wav = None, None, None
    _text = f"<|endofprompt|><|sil|>{text}" + ("<|sil|>" if is_last else "")
    text_token = self.tts_tokenizer_warpper(_text)
    t_size = len(text_token)
    if (t_size - last_t_size) >= tts_text_chunk_size or is_last:
        text_token = torch.tensor([text_token], dtype=torch.long, device=device)
        text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.long, device=device)
        cur_token, feat = self.tts_model.streaming_one_step(
            text_token, text_token_len,
            xvec=None, xvec_lengths=None,
            prompt_dict={
                "prompt_token": prompt_token,
                "prompt_audio": prompt_audio,
            },
            outside_prompt=llm_cur_kv_cache,
            outside_prompt_lengths=llm_cur_kv_cache_len,
            sampling="threshold_1e-6",
            chunk_idx=chunk_idx,
        )
        if cur_token is not None and cur_token.shape[1] > 0 and feat.shape[2] > 0:
            # token in B,T,D, feat in B,F,T
            mel = feat.transpose(1, 2)
            if prompt_token[0] is None:
                # first package: start the history
                prompt_token = [
                    cur_token,
                    torch.tensor([cur_token.shape[1]], dtype=torch.long, device=device),
                ]
                prompt_audio = [
                    mel,
                    torch.tensor([feat.shape[2]], dtype=torch.long, device=device),
                ]
            else:
                # Build new lists rather than assigning into the caller's
                # ones, so both branches leave the arguments untouched.
                prompt_token = [
                    torch.concat([prompt_token[0], cur_token], dim=1),
                    prompt_token[1] + cur_token.shape[1],
                ]
                prompt_audio = [
                    torch.concat([prompt_audio[0], mel], dim=1),
                    prompt_audio[1] + feat.shape[2],
                ]
            wav = self.vocoder.inference(mel)
            chunk_idx += 1
        else:
            # An empty package is reported like "nothing synthesized".
            cur_token, feat, wav = None, None, None

        # post process: remember how much text this step has consumed
        last_t_size = t_size
        # NOTE: restarting a new paragraph after `para_len` items is
        # currently disabled.

    return ((cur_token, feat, wav),
            (text, last_t_size, prompt_token, prompt_audio, chunk_idx))
|
||||
|
||||
def simulate_streaming_generate_speech(
    self, preds, llm_cur_kv_cache, llm_cur_kv_cache_len, llm_dtype, llm_tokenizer,
    llm_token_num_per_call=3, text_chunk_size=8, given_rtf=0.5,
):
    """Synthesize speech for ``preds`` by replaying it as a token stream.

    ``preds`` is consumed in growing chunks along dimension 1: chunk ``k``
    holds ``int(llm_token_num_per_call / given_rtf ** min(k, 2))`` token ids
    (3, 6, 12, 12, ... with the defaults). Each chunk is decoded, appended
    to the running text and handed to ``generate_speech_one_step``; the
    packages it yields are concatenated at the end.

    Args:
        preds: LLM output token ids; only row 0 of the decoded batch is used.
        llm_cur_kv_cache: Forwarded to ``generate_speech_one_step``.
        llm_cur_kv_cache_len: Forwarded to ``generate_speech_one_step``.
        llm_dtype: dtype the vocoder and TTS model are converted to.
        llm_tokenizer: Tokenizer providing ``batch_decode``.
        llm_token_num_per_call: Size of the first chunk of LLM tokens.
        text_chunk_size: Minimum number of new TTS text tokens per
            synthesis step.
        given_rtf: Growth factor base for the chunk size; smaller values
            make the second and third chunks larger.

    Returns:
        ``(speech_tokens, mel_feats, wav)`` concatenated over all packages
        (tokens along dim 1, mel along dim 2, waveform along dim 1).

    Raises:
        RuntimeError: If no step produced any audio (for example when
            ``preds`` has no tokens).
    """
    self.vocoder.to(llm_dtype)
    self.tts_model.to(llm_dtype)

    token_list, feat_list, wav_list = [], [], []
    prompt_token, prompt_audio = [None, None], [None, None]
    new_text, last_t_size, chunk_idx = "", 0, 0
    st, count = 0, 0
    total = preds.shape[1]
    while st < total:
        # Chunk size stops growing after the third call.
        chunk_size = int(llm_token_num_per_call / (given_rtf ** min(count, 2)))
        _resp = llm_tokenizer.batch_decode(
            preds[:, st:st + chunk_size],
            add_special_tokens=False,
            skip_special_tokens=True,
        )[0]
        is_last = (st + chunk_size >= total)

        new_text = new_text + _resp
        rt_value, states = self.generate_speech_one_step(
            new_text, last_t_size,
            llm_cur_kv_cache, llm_cur_kv_cache_len,
            prompt_token, prompt_audio,
            text_chunk_size,
            chunk_idx, is_last,
        )
        cur_token, feat, wav = rt_value
        new_text, last_t_size, prompt_token, prompt_audio, chunk_idx = states
        # save results
        if cur_token is not None and feat is not None and wav is not None:
            token_list.append(cur_token)
            feat_list.append(feat)
            wav_list.append(wav)

        st += chunk_size
        count += 1

    if not token_list:
        # torch.cat rejects an empty list with an unhelpful message; fail
        # with one that names the actual cause.
        raise RuntimeError(
            "simulate_streaming_generate_speech produced no speech packages"
        )

    speech_tokens = torch.cat(token_list, dim=1)
    mel_feats = torch.cat(feat_list, dim=2)
    wav = torch.cat(wav_list, dim=1)
    return speech_tokens, mel_feats, wav
|
||||
|
||||
def write_mel_wav(self, output_dir, feat, wav, key):
|
||||
out_dir = os.path.join(output_dir, "1best_recog", "mels")
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
|
||||
@ -741,7 +741,7 @@ class UCTDXvecSlotModel(UpsampleCtcTokenDiffModel):
|
||||
device = text.device
|
||||
use_causal_prob = kwargs.get("use_causal_prob", 1.0)
|
||||
# streaming related config
|
||||
chunk_size = kwargs.get("streaming_chunk_size", 1)
|
||||
chunk_size = kwargs.get("streaming_chunk_size", 4)
|
||||
chunk_size_maxium = kwargs.get("chunk_size_maxium", 16)
|
||||
try:
|
||||
lookahead_size = self.am_model.encoder.pre_lookahead_len
|
||||
@ -899,3 +899,129 @@ class UCTDXvecSlotModel(UpsampleCtcTokenDiffModel):
|
||||
break
|
||||
|
||||
return tokens, prompt_audio[0].transpose(1, 2)
|
||||
|
||||
def streaming_one_step(
    self, text: torch.Tensor, text_lengths: torch.Tensor,
    xvec: Optional[torch.Tensor] = None, xvec_lengths: Optional[torch.Tensor] = None,
    chunk_idx=0,
    **kwargs
):
    """Generate the next package of speech tokens and mel features.

    The full text seen so far is encoded and passed through the AM model;
    the part of its aligned output that lies beyond the already generated
    ``prompt_token`` history is passed through the FM model to obtain mel
    features. Unless the text ends with token id
    ``self.endofprompt_token_id + 1``, trailing tokens and frames are
    trimmed before returning.

    Args:
        text: Text token ids; position ``[0, -1]`` is inspected, and ``-1``
            entries are masked out of the embedding.
        text_lengths: Lengths belonging to ``text``.
        xvec: Optional speaker embedding. When ``None`` one is derived from
            the prompt through ``self.spk_aggregator``.
        xvec_lengths: Lengths belonging to ``xvec``.
        chunk_idx: Index of the package being generated; the first package
            (0) is trimmed by a smaller amount.
        **kwargs: Reads ``use_causal_prob``, ``streaming_chunk_size``,
            ``chunk_size_maxium``, ``blank_penalty``, ``sampling``,
            ``prompt_dict`` (keys ``prompt_token`` / ``prompt_audio``),
            ``outside_prompt`` and ``outside_prompt_lengths``. All of them
            are also forwarded to ``self.fm_model.inference``.

    Returns:
        ``(cur_token, feat)``. Both are ``None`` when the AM output does not
        extend past the prompt history.
    """
    device = text.device
    use_causal_prob = kwargs.get("use_causal_prob", 1.0)
    # streaming related config
    chunk_size = kwargs.get("streaming_chunk_size", 4)
    chunk_size_maxium = kwargs.get("chunk_size_maxium", 16)
    lookahead_size = self.am_model.encoder.pre_lookahead_len
    hint_once(f"chunk_size={chunk_size}, chunk_size_maxium={chunk_size_maxium}, "
              f"pre lookahead size={lookahead_size}.",
              "pre_lookahead_len")

    blank_penalty = kwargs.get("blank_penalty", 0.0)
    sampling = kwargs.get("sampling", "greedy")
    prompt_dict = kwargs.get("prompt_dict", {})
    # Copy the lists so filling in the empty defaults below does not touch
    # the caller's objects.
    prompt_token = list(prompt_dict.get("prompt_token", [None, None]))
    prompt_audio = list(prompt_dict.get("prompt_audio", [None, None]))

    ftype = self.text_embedding.weight.dtype
    if prompt_token[0] is None:
        prompt_token[0] = torch.zeros([1, 0, self.output_size], device=device, dtype=ftype)
        prompt_token[1] = torch.tensor([0], device=device, dtype=torch.long)
    if prompt_audio[0] is None:
        prompt_audio[0] = torch.zeros(
            [1, 0, self.fm_model.mel_extractor.num_mels],
            device=device, dtype=ftype
        )
        prompt_audio[1] = torch.tensor([0], device=device, dtype=torch.long)

    # embed text inputs; positions holding -1 are zeroed out
    mask = (text != -1).float().unsqueeze(-1)
    text_emb = self.text_embedding(torch.clamp(text, min=0)) * mask
    text_emb_lengths = text_lengths

    batch_size = text.shape[0]

    prompt, prompt_lens, text_emb, text_emb_lengths = self.split_prompt(
        text_emb, text_emb_lengths, text, text_lengths
    )
    if "outside_prompt" in kwargs:
        # An externally supplied prompt replaces the one split from the text.
        prompt = kwargs["outside_prompt"].to(device)
        if "outside_prompt_lengths" in kwargs:
            prompt_lens = kwargs["outside_prompt_lengths"]
        else:
            prompt_lens = torch.tensor([prompt.shape[1]]).to(text_lengths)
        prompt = self.outside_prompt_poj(prompt)
        hint_once("use outside_prompt", "outside_prompt")

    if xvec is not None:
        # using speaker embedding
        hint_once("using speaker embedding for slot.", "use_spk_emb")
        xvec = xvec[:, :xvec_lengths.max()]
    else:
        # textual prompt xvec
        hint_once("using textual prompt for slot.", "use_spk_emb")
        prompt_xvec = self.spk_aggregator(
            prompt, prompt_lens,
            self.spk_query.expand([batch_size, -1, -1]), torch.tensor([1] * batch_size).to(prompt_lens)
        )[0]
        xvec = self.prompt_xvec_proj(prompt_xvec)
        xvec_lengths = torch.tensor([1] * batch_size).to(text_lengths)

    chunk_text_emb = text_emb
    chunk_text_emb_lengths = torch.tensor([chunk_text_emb.shape[1]], dtype=torch.long, device=device)

    outs_tuple = self.text_encoder(chunk_text_emb, ilens=chunk_text_emb_lengths)
    text_enc = outs_tuple[0]
    text_enc_lens = chunk_text_emb_lengths

    # forward AM model
    tokens, aligned_token_emb, aligned_token_lens = self.am_model.inference(
        text_enc, text_enc_lens,
        xvec, xvec_lengths,
        sampling=sampling,
        blank_penalty=blank_penalty,
        text_is_embedding=True,
        return_hidden=True,
        use_causal_prob=use_causal_prob,
    )
    # `fa_tokens` only exists when the AM model returns a (tokens, fa_tokens)
    # pair; keep it defined so the trimming below can test for it.
    token_hop_len, fa_tokens = 0, None
    if isinstance(tokens, tuple):
        tokens, fa_tokens = tokens
        token_hop_len = self.get_hop_lens(fa_tokens, lookahead_size)

    # NOTE(review): `endofprompt_token_id + 1` is presumably the id of the
    # closing "<|sil|>" that marks the last package - confirm.
    not_last_package = bool(text[0, -1] != self.endofprompt_token_id + 1)

    cur_token, feat = None, None
    # exclude empty tokens.
    if aligned_token_emb.shape[1] > prompt_token[0].shape[1]:
        cur_token = aligned_token_emb[:, prompt_token[0].shape[1]:]
        cur_token_len = aligned_token_lens - prompt_token[1]

        # v2: excluding lookahead tokens for not-last packages
        if not_last_package:
            cur_token = cur_token[:, :cur_token.shape[1] - token_hop_len, :]
            cur_token_len = cur_token_len - token_hop_len

        # forward FM model
        feat = self.fm_model.inference(
            cur_token, cur_token_len,
            xvec, xvec_lengths,
            prompt=dict(
                prompt_text=prompt_token,
                prompt_audio=prompt_audio,
            ),
            **kwargs,
        )
        feat = self.rms_rescale_feat(feat)
        print_token = tokens.cpu().squeeze().tolist()
        logging.info(f"valid_tokens: {print_token[:len(print_token) - token_hop_len]}, "
                     f"pad_tokens: {print_token[len(print_token) - token_hop_len:]}.")

        # v2: reback token and mel feat. Skipped when the AM model gave no
        # `fa_tokens`, where this used to fail with a NameError.
        if not_last_package and fa_tokens is not None:
            text_reback = 2 if chunk_idx == 0 else 4
            token_hop_len_2 = self.get_hop_lens(fa_tokens, lookahead_size + text_reback)
            token_reback = token_hop_len_2 - token_hop_len
            feat_reback = int(round(token_reback * self.fm_model.length_normalizer_ratio))
            # Clamp at 0: a negative slice end would wrap around and keep
            # data instead of returning an empty package.
            token_keep = max(cur_token.shape[1] - token_reback, 0)
            feat_keep = max(feat.shape[2] - feat_reback, 0)
            cur_token = cur_token[:, :token_keep, :]
            feat = feat[:, :, :feat_keep]

    return cur_token, feat
|
||||
|
||||
@ -29,6 +29,7 @@ parser.add_argument("--ngpu", type=int, default=1, help="0 for cpu, 1 for gpu")
|
||||
parser.add_argument("--device", type=str, default="cuda", help="cuda, cpu")
|
||||
parser.add_argument("--ncpu", type=int, default=4, help="cpu cores")
|
||||
parser.add_argument("--return_sentence", action="store_true", help="return sentence or all_res")
|
||||
parser.add_argument("--no_vad", action="store_true", help="infer without vad")
|
||||
parser.add_argument(
|
||||
"--certfile",
|
||||
type=str,
|
||||
@ -78,6 +79,7 @@ audio_encoder_dir = snapshot_download("iic/SenseVoice", cache_dir=None, revision
|
||||
# audio_encoder_dir = "/nfs/yangyexin.yyx/init_model/iic/SenseVoiceModelscope_0712"
|
||||
device = "cuda:0"
|
||||
all_file_paths = [
|
||||
# "/nfs/yangyexin.yyx/init_model/s2tt/qwen2_7b_mmt_v15_20240910_streaming",
|
||||
"FunAudioLLM/qwen2_7b_mmt_v15_20240910_streaming",
|
||||
"FunAudioLLM/qwen2_7b_mmt_v15_20240902",
|
||||
"FunAudioLLM/qwen2_7b_mmt_v14_20240830",
|
||||
@ -91,7 +93,6 @@ llm_kwargs = {"num_beams": 1, "do_sample": False, "repetition_penalty": 1.3}
|
||||
UNFIX_LEN = 5
|
||||
MIN_LEN_PER_PARAGRAPH = 25
|
||||
MIN_LEN_SEC_AUDIO_FIX = 1.1
|
||||
MAX_ITER_PER_CHUNK = 20
|
||||
|
||||
ckpt_dir = all_file_paths[0]
|
||||
|
||||
@ -483,32 +484,51 @@ async def ws_serve(websocket, path):
|
||||
frames_asr.append(message)
|
||||
|
||||
# vad online
|
||||
try:
|
||||
speech_start_i, speech_end_i = await async_vad(websocket, message)
|
||||
except:
|
||||
print("error in vad")
|
||||
if speech_start_i != -1:
|
||||
if not args.no_vad:
|
||||
try:
|
||||
speech_start_i, speech_end_i = await async_vad(websocket, message)
|
||||
except:
|
||||
print("error in vad")
|
||||
if speech_start_i != -1:
|
||||
speech_start = True
|
||||
speech_end_i = -1
|
||||
beg_bias = (websocket.vad_pre_idx - speech_start_i) // duration_ms
|
||||
frames_pre = frames[-beg_bias:]
|
||||
frames_asr = []
|
||||
frames_asr.extend(frames_pre)
|
||||
else:
|
||||
speech_start = True
|
||||
beg_bias = (websocket.vad_pre_idx - speech_start_i) // duration_ms
|
||||
frames_pre = frames[-beg_bias:]
|
||||
speech_end_i = -1
|
||||
frames_asr = []
|
||||
frames_asr.extend(frames_pre)
|
||||
frames_asr.extend(frames)
|
||||
|
||||
# vad end
|
||||
if speech_end_i != -1 or not websocket.is_speaking:
|
||||
audio_in = b"".join(frames_asr)
|
||||
try:
|
||||
await streaming_transcribe(
|
||||
websocket, audio_in, is_vad_end=True, asr_prompt=asr_prompt, s2tt_prompt=s2tt_prompt
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"error in streaming, {e}")
|
||||
print(f"error in streaming, {websocket.streaming_state}")
|
||||
if speech_end_i != -1:
|
||||
audio_in = b"".join(frames_asr)
|
||||
try:
|
||||
await streaming_transcribe(
|
||||
websocket, audio_in, is_vad_end=True, asr_prompt=asr_prompt, s2tt_prompt=s2tt_prompt
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"error in streaming, {e}")
|
||||
print(f"error in streaming, {websocket.streaming_state}")
|
||||
frames_asr = []
|
||||
speech_start = False
|
||||
websocket.streaming_state["previous_asr_text"] = ""
|
||||
websocket.streaming_state["previous_s2tt_text"] = ""
|
||||
if not websocket.is_speaking:
|
||||
message = json.dumps(
|
||||
{
|
||||
"mode": "online",
|
||||
"asr_text": websocket.streaming_state["onscreen_asr_res"] + "<em></em>",
|
||||
"s2tt_text": websocket.streaming_state["onscreen_s2tt_res"] + "<em></em>",
|
||||
"wav_name": websocket.wav_name,
|
||||
"is_final": websocket.is_speaking,
|
||||
"is_sentence_end": True,
|
||||
}
|
||||
)
|
||||
await websocket.send(message)
|
||||
await clear_websocket()
|
||||
if args.return_sentence:
|
||||
websocket.streaming_state["previous_vad_onscreen_asr_text"] = ""
|
||||
|
||||
Loading…
Reference in New Issue
Block a user