mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
speech2speech
This commit is contained in:
parent
6c51cc2e0a
commit
d46c2df2e1
@ -2204,6 +2204,7 @@ class LLMASRXvecSlotTTS(nn.Module):
|
||||
load_in_8bit=None,
|
||||
device_map=None,
|
||||
use_cache=None,
|
||||
output_hidden_states=True,
|
||||
)
|
||||
else:
|
||||
import os
|
||||
@ -2807,7 +2808,7 @@ class LLMASRXvecSlotTTS(nn.Module):
|
||||
enabled=True if llm_dtype != "fp32" else False, dtype=dtype_map[llm_dtype]
|
||||
):
|
||||
label = contents["assistant"][-1]
|
||||
self.llm = self.llm.to(dtype_map[llm_dtype])
|
||||
self.llm.to(dtype_map[llm_dtype])
|
||||
inputs_embeds = inputs_embeds.to(dtype_map[llm_dtype])
|
||||
llm_kwargs = kwargs.get("llm_kwargs", {})
|
||||
if not kwargs.get("tearchforing", False):
|
||||
@ -2820,8 +2821,7 @@ class LLMASRXvecSlotTTS(nn.Module):
|
||||
output_scores=True,
|
||||
**llm_kwargs,
|
||||
)
|
||||
# hidden_states: (t1, t2, ..., tn, ..., tN), tn=(l1, l2, ..., ln, ..., lN), ln: shape: 1x1x3584
|
||||
hidden_states = generated_ids["hidden_states"].hidden_states[-1].float()
|
||||
|
||||
# TODO: get llm_cur_kv_cache
|
||||
|
||||
target_ids = generated_ids["sequences"]
|
||||
@ -2871,30 +2871,40 @@ class LLMASRXvecSlotTTS(nn.Module):
|
||||
|
||||
# tts related inference, require the kv cache of llm last layer for only the current inputs
|
||||
# TODO: select kv cache of the current turn inputs
|
||||
import pdb
|
||||
|
||||
pdb.set_trace()
|
||||
attention_mask = batch.get("attention_mask", None)
|
||||
model_outputs = self.llm(
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=None,
|
||||
labels=None,
|
||||
**llm_kwargs,
|
||||
)
|
||||
hidden_states = model_outputs.hidden_states[-1].float()
|
||||
# hidden_states = generated_ids[
|
||||
# "hidden_states"
|
||||
# ] # hidden_states: (t1, t2, ..., tn, ..., tN), tn=(l1, l2, ..., ln, ..., lN), ln: shape: 1x1x3584
|
||||
|
||||
token_num = len(hidden_states)
|
||||
hidden_states_select = torch.zeros((1, token_num, 3584), dtype=torch.float32).to(
|
||||
inputs_embeds.device
|
||||
)
|
||||
|
||||
for i in range(token_num):
|
||||
hidden_states_select[0, i, :] = hidden_states[i][-1][0, 0, :].to(torch.float32)
|
||||
# token_num = len(hidden_states)
|
||||
# hidden_states_select = torch.zeros((1, token_num, 3584), dtype=torch.float32).to(
|
||||
# inputs_embeds.device
|
||||
# )
|
||||
#
|
||||
# for i in range(token_num):
|
||||
# hidden_states_select[0, i, :] = hidden_states[i][-1][0, 0, :].to(torch.float32)
|
||||
|
||||
llm_cur_kv_cache, llm_cur_kv_cache_len = None, None
|
||||
|
||||
input_mask_beg = outputs.get("input_mask_beg")
|
||||
input_mask_beg = batch.get("input_mask_beg")[0][None, :]
|
||||
input_mask_beg[input_mask_beg < 0] = 0
|
||||
input_mask = outputs.get("input_mask")
|
||||
input_mask = batch.get("input_mask")[0][None, :]
|
||||
input_mask[input_mask < 0] = 0
|
||||
|
||||
for turn_id_cum in range(input_mask.shape[0]):
|
||||
beg = input_mask_beg[turn_id_cum].sum(-1)
|
||||
end = input_mask[turn_id_cum].sum(-1)
|
||||
llm_cur_kv_cache = hidden_states_select[:, beg:end, :]
|
||||
llm_cur_kv_cache = hidden_states[:, beg:end, :]
|
||||
llm_cur_kv_cache_len = torch.tensor(
|
||||
[
|
||||
end - beg,
|
||||
@ -2916,14 +2926,15 @@ class LLMASRXvecSlotTTS(nn.Module):
|
||||
return results, meta_data
|
||||
|
||||
def generate_speech(self, text, llm_cur_kv_cache, llm_cur_kv_cache_len, llm_dtype):
|
||||
self.tts_text_tokenizer = self.tts_text_tokenizer.to(llm_dtype)
|
||||
self.vocoder = self.vocoder.to(llm_dtype)
|
||||
# self.tts_text_tokenizer = self.tts_text_tokenizer
|
||||
self.vocoder.to(llm_dtype)
|
||||
device = llm_cur_kv_cache.device
|
||||
# tokenize text
|
||||
text_token = self.tts_text_tokenizer.text2tokens(f"<|endofprompt|><|sil|>{text}<|sil|>")
|
||||
text_token = torch.tensor([text_token], llm_dtype, device)
|
||||
text_token = torch.tensor([text_token], dtype=torch.long, device=device)
|
||||
text_token_len = torch.tensor([text_token.shape[1]], torch.long, device)
|
||||
# e2e tts model forward
|
||||
self.tts_model.to(llm_dtype)
|
||||
speech_tokens, mel_feats = self.tts_model.inference(
|
||||
text_token,
|
||||
text_token_len,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user