speech2speech

This commit is contained in:
游雁 2024-09-11 16:15:35 +08:00
parent 6c51cc2e0a
commit d46c2df2e1


@@ -2204,6 +2204,7 @@ class LLMASRXvecSlotTTS(nn.Module):
load_in_8bit=None,
device_map=None,
use_cache=None,
output_hidden_states=True,
)
else:
import os
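For context: with Hugging Face transformers, `output_hidden_states=True` must be set (in the config passed to `from_pretrained`, as this hunk does, or per call) for `forward`/`generate` to return per-layer hidden states at all. A minimal sketch of the effect, assuming a standard `AutoModelForCausalLM`; the checkpoint name is a placeholder, chosen because Qwen2-7B's hidden size matches the 3584 appearing in later hunks:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct")
llm = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-7B-Instruct",
    output_hidden_states=True,  # the flag this hunk adds
)

inputs = tok("hello", return_tensors="pt")
out = llm(**inputs)
# out.hidden_states: tuple of (num_layers + 1) tensors, each of shape
# (batch, seq_len, hidden_size); [-1] selects the last layer.
last_layer = out.hidden_states[-1]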
@@ -2807,7 +2808,7 @@ class LLMASRXvecSlotTTS(nn.Module):
enabled=True if llm_dtype != "fp32" else False, dtype=dtype_map[llm_dtype]
):
label = contents["assistant"][-1]
self.llm = self.llm.to(dtype_map[llm_dtype])
self.llm.to(dtype_map[llm_dtype])
inputs_embeds = inputs_embeds.to(dtype_map[llm_dtype])
llm_kwargs = kwargs.get("llm_kwargs", {})
if not kwargs.get("tearchforing", False):
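The `self.llm = self.llm.to(...)` → `self.llm.to(...)` change is behavior-preserving: `nn.Module.to()` casts parameters in place and returns `self`, so the reassignment was redundant. `Tensor.to()`, by contrast, returns a new tensor, which is why the `inputs_embeds` line must keep its assignment. A quick illustration:

import torch
import torch.nn as nn

m = nn.Linear(4, 4)
m.to(torch.bfloat16)            # in place: module parameters are cast
assert m.weight.dtype == torch.bfloat16

x = torch.randn(1, 4)
x.to(torch.bfloat16)            # no effect: Tensor.to() is out of place
assert x.dtype == torch.float32
x = x.to(torch.bfloat16)        # assignment required for tensors
assert x.dtype == torch.bfloat16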
@@ -2820,8 +2821,7 @@ class LLMASRXvecSlotTTS(nn.Module):
output_scores=True,
**llm_kwargs,
)
# hidden_states: (t1, t2, ..., tn, ..., tN), tn=(l1, l2, ..., ln, ..., lN), ln: shape: 1x1x3584
hidden_states = generated_ids["hidden_states"].hidden_states[-1].float()
# TODO: get llm_cur_kv_cache
target_ids = generated_ids["sequences"]
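The comment removed here documents the layout of `generate`'s hidden states: a tuple over generation steps, each step itself a tuple over layers, each entry of shape (batch, cur_len, hidden). A sketch of pulling one last-layer vector per generated token, continuing the illustrative setup above (`llm`, `tok`):

inputs = tok("hello", return_tensors="pt")
inputs_embeds = llm.get_input_embeddings()(inputs["input_ids"])
out = llm.generate(
    inputs_embeds=inputs_embeds,
    return_dict_in_generate=True,
    output_hidden_states=True,
    output_scores=True,
    max_new_tokens=8,
)
# out.hidden_states: (t1, ..., tN); each tn = (layer_0, ..., layer_L);
# step 0 covers the prompt, each later step covers one new token.
per_token = [step[-1][:, -1, :] for step in out.hidden_states]
last_layer_states = torch.stack(per_token, dim=1)  # (batch, N, hidden_size)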
@@ -2871,30 +2871,40 @@ class LLMASRXvecSlotTTS(nn.Module):
# tts related inference, require the kv cache of llm last layer for only the current inputs
# TODO: select kv cache of the current turn inputs
import pdb
pdb.set_trace()
attention_mask = batch.get("attention_mask", None)
model_outputs = self.llm(
inputs_embeds=inputs_embeds,
attention_mask=None,
labels=None,
**llm_kwargs,
)
hidden_states = model_outputs.hidden_states[-1].float()
# hidden_states = generated_ids[
# "hidden_states"
# ] # hidden_states: (t1, t2, ..., tn, ..., tN), tn=(l1, l2, ..., ln, ..., lN), ln: shape: 1x1x3584
token_num = len(hidden_states)
hidden_states_select = torch.zeros((1, token_num, 3584), dtype=torch.float32).to(
inputs_embeds.device
)
for i in range(token_num):
hidden_states_select[0, i, :] = hidden_states[i][-1][0, 0, :].to(torch.float32)
# token_num = len(hidden_states)
# hidden_states_select = torch.zeros((1, token_num, 3584), dtype=torch.float32).to(
# inputs_embeds.device
# )
#
# for i in range(token_num):
# hidden_states_select[0, i, :] = hidden_states[i][-1][0, 0, :].to(torch.float32)
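The loop commented out here built `hidden_states_select` token by token from the `generate` output (`hidden_states[i][-1][0, 0, :]` = step i, last layer, batch 0, position 0). The replacement reruns the LLM once over the full input and takes `model_outputs.hidden_states[-1]` directly, which already has shape (batch, seq_len, hidden) and needs no gather. A sketch of the new path, reusing `llm` and `inputs_embeds` from the sketches above; here `output_hidden_states=True` is passed explicitly, though the commit instead bakes it into `from_pretrained` in the first hunk:

model_outputs = llm(
    inputs_embeds=inputs_embeds,
    attention_mask=None,
    output_hidden_states=True,
)
# (batch, seq_len, hidden_size); 3584 for e.g. Qwen2-7B
hidden_states = model_outputs.hidden_states[-1].float()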
llm_cur_kv_cache, llm_cur_kv_cache_len = None, None
input_mask_beg = outputs.get("input_mask_beg")
input_mask_beg = batch.get("input_mask_beg")[0][None, :]
input_mask_beg[input_mask_beg < 0] = 0
input_mask = outputs.get("input_mask")
input_mask = batch.get("input_mask")[0][None, :]
input_mask[input_mask < 0] = 0
for turn_id_cum in range(input_mask.shape[0]):
beg = input_mask_beg[turn_id_cum].sum(-1)
end = input_mask[turn_id_cum].sum(-1)
llm_cur_kv_cache = hidden_states_select[:, beg:end, :]
llm_cur_kv_cache = hidden_states[:, beg:end, :]
llm_cur_kv_cache_len = torch.tensor(
[
end - beg,
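The begin/end offsets in this hunk come from summing binary masks: `input_mask_beg` marks the tokens before the current turn, `input_mask` marks the tokens through the end of it, so `hidden_states[:, beg:end, :]` isolates the current turn's states. A minimal sketch with made-up masks:

import torch

hidden = torch.randn(1, 10, 3584)
# 1s mark covered positions; negative entries are padding and are
# clamped to 0 first, as in the diff.
input_mask_beg = torch.tensor([[1, 1, 1, 0, 0, 0, 0, 0, 0, -1]])
input_mask     = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 0, 0, -1]])
input_mask_beg[input_mask_beg < 0] = 0
input_mask[input_mask < 0] = 0

beg = input_mask_beg[0].sum(-1)   # 3 tokens precede the current turn
end = input_mask[0].sum(-1)       # 7 tokens through the end of it
cur = hidden[:, beg:end, :]       # (1, 4, 3584): current-turn states only
cur_len = torch.tensor([end - beg], dtype=torch.int32)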
@@ -2916,14 +2926,15 @@ class LLMASRXvecSlotTTS(nn.Module):
return results, meta_data
def generate_speech(self, text, llm_cur_kv_cache, llm_cur_kv_cache_len, llm_dtype):
self.tts_text_tokenizer = self.tts_text_tokenizer.to(llm_dtype)
self.vocoder = self.vocoder.to(llm_dtype)
# self.tts_text_tokenizer = self.tts_text_tokenizer
self.vocoder.to(llm_dtype)
device = llm_cur_kv_cache.device
# tokenize text
text_token = self.tts_text_tokenizer.text2tokens(f"<|endofprompt|><|sil|>{text}<|sil|>")
text_token = torch.tensor([text_token], llm_dtype, device)
text_token = torch.tensor([text_token], dtype=torch.long, device=device)
text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.long, device=device)
# e2e tts model forward
self.tts_model.to(llm_dtype)
speech_tokens, mel_feats = self.tts_model.inference(
text_token,
text_token_len,
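One note on the tokenization step: `torch.tensor`'s `dtype` and `device` parameters are keyword-only, so passing them positionally (as the old `text_token` line did) raises a TypeError; the `text_token_len` line needs the same keyword form. A condensed sketch of the corrected construction, with placeholder token ids standing in for `tts_text_tokenizer.text2tokens(...)` output:

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
text_token = [42, 7, 99]  # placeholder ids

# dtype/device must be keywords; positional args raise a TypeError.
text_token = torch.tensor([text_token], dtype=torch.long, device=device)
text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.long, device=device)
assert text_token.shape == (1, 3) and text_token_len.item() == 3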