This commit is contained in:
游雁 2024-07-04 11:07:27 +08:00
parent 05acd675ec
commit 63800cb852

View File

@ -2731,19 +2731,27 @@ class LLMASR5(nn.Module):
hidden_states_out = torch.zeros((1, token_num, 3584), dtype=torch.float32).to(
inputs_embeds.device
)
hidden_states_out_len = torch.tensor(
[
token_num,
],
dtype=torch.int32,
).to(inputs_embeds.device)
for i in range(token_num):
hidden_states_out[0, i, :] = hidden_states[1][-1][0, 0, :].to(torch.float32)
speech_tokens = audio_decode(hidden_states)
speech_tokens = self.audio_decode(
hidden_states_out, hidden_states_out_len
) # 1xl: 2,10,1023
sequences = generated_ids["sequences"]
# generated_ids = [
# output_ids[len(input_id) :]
# for input_id, output_ids in zip(input_ids, generated_ids)
# ]
# response = tokenizer.batch_decode(
# generated_ids, skip_special_tokens=kwargs.get("skip_special_tokens", True)
# )[0]
response = tokenizer.batch_decode(
sequences, skip_special_tokens=kwargs.get("skip_special_tokens", True)
)[0]
loss = None
@ -2755,33 +2763,49 @@ class LLMASR5(nn.Module):
results = []
response_clean = re.sub("[^\w\s\u3000\u4e00-\u9fff]+", "", response)
result_i = {"key": key[0], "text": response, "text_tn": response_clean, "label": label}
result_i = {
"key": key[0],
"text": response,
"text_tn": response_clean,
"label": label,
"speech_tokens": speech_tokens,
}
if loss is not None:
result_i["loss"] = loss
results.append(result_i)
speech_tokens_out = "<|startofspeech|>"
for i in range(speech_tokens.shape[-1]):
tmp = speech_tokens[0, i].item()
speech_tokens_out += f"<|c{tmp}|>"
speech_tokens_out += "<|endofspeech|><|im_end|>"
if ibest_writer is not None:
ibest_writer["text"][key[0]] = response.replace("\n", " ")
ibest_writer["label"][key[0]] = label.replace("\n", " ")
ibest_writer["text_tn"][key[0]] = response_clean
ibest_writer["speech_tokens"][key[0]] = speech_tokens_out
return results, meta_data
def audio_decode(
self,
text: torch.Tensor,
text_lengths: torch.Tensor,
min_length=None,
max_length: int = 30 * 25,
infer_cfg_ratio=None,
decoding_length=None,
self,
text: torch.Tensor,
text_lengths: torch.Tensor,
min_length=None,
max_length: int = 30 * 25,
infer_cfg_ratio=None,
decoding_length=None,
):
# 1. encode text
text = self.audio_decoder_in_proj(text)
device = text.device
out_tokens = []
sos_eos_emb = self.audio_decoder_embedding(torch.tensor([[self.ad_sos_eos]], dtype=torch.int64, device=device))
task_id_emb = self.audio_decoder_embedding(torch.tensor([[self.ad_task_id]], dtype=torch.int64, device=device))
sos_eos_emb = self.audio_decoder_embedding(
torch.tensor([[self.ad_sos_eos]], dtype=torch.int64, device=device)
)
task_id_emb = self.audio_decoder_embedding(
torch.tensor([[self.ad_task_id]], dtype=torch.int64, device=device)
)
prompt = torch.cat([sos_eos_emb, text, task_id_emb], dim=1)
state, cfg_state = None, None
for i in range(max_length):
@ -2789,31 +2813,22 @@ class LLMASR5(nn.Module):
codec_prompt = torch.tensor([out_tokens], dtype=torch.int64, device=device)
codec_lengths = torch.tensor([len(out_tokens)], dtype=torch.int64, device=device)
# if any quantizer output is eos
if torch.any(codec_prompt[:, -1] == (self.codebook_size+self.sos_eos)):
if torch.any(codec_prompt[:, -1] == (self.codebook_size + self.ad_sos_eos)):
break
seq_input, _ = self.prepare_audio_decoder_io(
text, text_lengths,
codec_prompt, codec_lengths,
need_targets=False
text, text_lengths, codec_prompt, codec_lengths, need_targets=False
)
else:
seq_input, _ = self.prepare_audio_decoder_io(
text, text_lengths, None, None,
need_targets=False
text, text_lengths, None, None, need_targets=False
)
# use state for speedup
pred, (state, _) = self.audio_decoder.score(
seq_input[0],
state,
prompt[0]
)
pred, (state, _) = self.audio_decoder.score(seq_input[0], state, prompt[0])
if infer_cfg_ratio is not None:
cond_len = prompt[0].shape[0]
cfg_pred, (cfg_state, _) = self.audio_decoder.score(
seq_input[0][cond_len-1:],
cfg_state,
prompt[0][cond_len-1:]
seq_input[0][cond_len - 1 :], cfg_state, prompt[0][cond_len - 1 :]
)
pred = (1 + infer_cfg_ratio) * pred - infer_cfg_ratio * cfg_pred
@ -2830,7 +2845,9 @@ class LLMASR5(nn.Module):
# remove eos token
hit_eos = False
if torch.any(torch.tensor(out_tokens[-1], dtype=torch.int64) == self.codebook_size+self.ad_sos_eos):
if torch.any(
torch.tensor(out_tokens[-1], dtype=torch.int64) == self.codebook_size + self.ad_sos_eos
):
hit_eos = True
out_tokens = out_tokens[:-1]
@ -2841,9 +2858,7 @@ class LLMASR5(nn.Module):
# Repetition Aware Sampling in VALL-E 2
def ras_sampling(
self,
weighted_scores, decoded_tokens, *,
top_p=0.8, top_k=25, win_size=10, tau_r=0.1
self, weighted_scores, decoded_tokens, *, top_p=0.8, top_k=25, win_size=10, tau_r=0.1
):
top_ids = self.nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k)
rep_num = (torch.tensor(decoded_tokens[-win_size:]).to(top_ids) == top_ids).sum().item()