From 63800cb852b063ea2001bbd716b8c0a15df3b3a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Thu, 4 Jul 2024 11:07:27 +0800 Subject: [PATCH] update --- funasr/models/llm_asr/model.py | 81 ++++++++++++++++++++-------------- 1 file changed, 48 insertions(+), 33 deletions(-) diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py index 6358a8ddc..89ada01ab 100644 --- a/funasr/models/llm_asr/model.py +++ b/funasr/models/llm_asr/model.py @@ -2731,19 +2731,27 @@ class LLMASR5(nn.Module): hidden_states_out = torch.zeros((1, token_num, 3584), dtype=torch.float32).to( inputs_embeds.device ) - + hidden_states_out_len = torch.tensor( + [ + token_num, + ], + dtype=torch.int32, + ).to(inputs_embeds.device) for i in range(token_num): hidden_states_out[0, i, :] = hidden_states[1][-1][0, 0, :].to(torch.float32) - speech_tokens = audio_decode(hidden_states) + speech_tokens = self.audio_decode( + hidden_states_out, hidden_states_out_len + ) # 1xl: 2,10,1023 + sequences = generated_ids["sequences"] # generated_ids = [ # output_ids[len(input_id) :] # for input_id, output_ids in zip(input_ids, generated_ids) # ] - # response = tokenizer.batch_decode( - # generated_ids, skip_special_tokens=kwargs.get("skip_special_tokens", True) - # )[0] + response = tokenizer.batch_decode( + sequences, skip_special_tokens=kwargs.get("skip_special_tokens", True) + )[0] loss = None @@ -2755,33 +2763,49 @@ class LLMASR5(nn.Module): results = [] response_clean = re.sub("[^\w\s\u3000\u4e00-\u9fff]+", "", response) - result_i = {"key": key[0], "text": response, "text_tn": response_clean, "label": label} + result_i = { + "key": key[0], + "text": response, + "text_tn": response_clean, + "label": label, + "speech_tokens": speech_tokens, + } if loss is not None: result_i["loss"] = loss results.append(result_i) + speech_tokens_out = "<|startofspeech|>" + for i in range(speech_tokens.shape[-1]): + tmp = speech_tokens[0, i].item() + speech_tokens_out += f"<|c{tmp}|>" + speech_tokens_out += "<|endofspeech|><|im_end|>" if ibest_writer is not None: ibest_writer["text"][key[0]] = response.replace("\n", " ") ibest_writer["label"][key[0]] = label.replace("\n", " ") ibest_writer["text_tn"][key[0]] = response_clean + ibest_writer["speech_tokens"][key[0]] = speech_tokens_out return results, meta_data def audio_decode( - self, - text: torch.Tensor, - text_lengths: torch.Tensor, - min_length=None, - max_length: int = 30 * 25, - infer_cfg_ratio=None, - decoding_length=None, + self, + text: torch.Tensor, + text_lengths: torch.Tensor, + min_length=None, + max_length: int = 30 * 25, + infer_cfg_ratio=None, + decoding_length=None, ): # 1. encode text text = self.audio_decoder_in_proj(text) device = text.device out_tokens = [] - sos_eos_emb = self.audio_decoder_embedding(torch.tensor([[self.ad_sos_eos]], dtype=torch.int64, device=device)) - task_id_emb = self.audio_decoder_embedding(torch.tensor([[self.ad_task_id]], dtype=torch.int64, device=device)) + sos_eos_emb = self.audio_decoder_embedding( + torch.tensor([[self.ad_sos_eos]], dtype=torch.int64, device=device) + ) + task_id_emb = self.audio_decoder_embedding( + torch.tensor([[self.ad_task_id]], dtype=torch.int64, device=device) + ) prompt = torch.cat([sos_eos_emb, text, task_id_emb], dim=1) state, cfg_state = None, None for i in range(max_length): @@ -2789,31 +2813,22 @@ class LLMASR5(nn.Module): codec_prompt = torch.tensor([out_tokens], dtype=torch.int64, device=device) codec_lengths = torch.tensor([len(out_tokens)], dtype=torch.int64, device=device) # if any quantizer output is eos - if torch.any(codec_prompt[:, -1] == (self.codebook_size+self.sos_eos)): + if torch.any(codec_prompt[:, -1] == (self.codebook_size + self.ad_sos_eos)): break seq_input, _ = self.prepare_audio_decoder_io( - text, text_lengths, - codec_prompt, codec_lengths, - need_targets=False + text, text_lengths, codec_prompt, codec_lengths, need_targets=False ) else: seq_input, _ = self.prepare_audio_decoder_io( - text, text_lengths, None, None, - need_targets=False + text, text_lengths, None, None, need_targets=False ) # use state for speedup - pred, (state, _) = self.audio_decoder.score( - seq_input[0], - state, - prompt[0] - ) + pred, (state, _) = self.audio_decoder.score(seq_input[0], state, prompt[0]) if infer_cfg_ratio is not None: cond_len = prompt[0].shape[0] cfg_pred, (cfg_state, _) = self.audio_decoder.score( - seq_input[0][cond_len-1:], - cfg_state, - prompt[0][cond_len-1:] + seq_input[0][cond_len - 1 :], cfg_state, prompt[0][cond_len - 1 :] ) pred = (1 + infer_cfg_ratio) * pred - infer_cfg_ratio * cfg_pred @@ -2830,7 +2845,9 @@ class LLMASR5(nn.Module): # remove eos token hit_eos = False - if torch.any(torch.tensor(out_tokens[-1], dtype=torch.int64) == self.codebook_size+self.ad_sos_eos): + if torch.any( + torch.tensor(out_tokens[-1], dtype=torch.int64) == self.codebook_size + self.ad_sos_eos + ): hit_eos = True out_tokens = out_tokens[:-1] @@ -2841,9 +2858,7 @@ class LLMASR5(nn.Module): # Repetition Aware Sampling in VALL-E 2 def ras_sampling( - self, - weighted_scores, decoded_tokens, *, - top_p=0.8, top_k=25, win_size=10, tau_r=0.1 + self, weighted_scores, decoded_tokens, *, top_p=0.8, top_k=25, win_size=10, tau_r=0.1 ): top_ids = self.nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k) rep_num = (torch.tensor(decoded_tokens[-win_size:]).to(top_ids) == top_ids).sum().item()