mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
add audio decoding
This commit is contained in:
parent
2ab9f44113
commit
05acd675ec
@ -6,6 +6,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch.cuda.amp import autocast
|
from torch.cuda.amp import autocast
|
||||||
|
import numpy as np
|
||||||
import re
|
import re
|
||||||
from funasr.models.scama.utils import sequence_mask
|
from funasr.models.scama.utils import sequence_mask
|
||||||
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
|
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
|
||||||
@ -2734,6 +2735,8 @@ class LLMASR5(nn.Module):
|
|||||||
for i in range(token_num):
|
for i in range(token_num):
|
||||||
hidden_states_out[0, i, :] = hidden_states[1][-1][0, 0, :].to(torch.float32)
|
hidden_states_out[0, i, :] = hidden_states[1][-1][0, 0, :].to(torch.float32)
|
||||||
|
|
||||||
|
speech_tokens = audio_decode(hidden_states)
|
||||||
|
|
||||||
# generated_ids = [
|
# generated_ids = [
|
||||||
# output_ids[len(input_id) :]
|
# output_ids[len(input_id) :]
|
||||||
# for input_id, output_ids in zip(input_ids, generated_ids)
|
# for input_id, output_ids in zip(input_ids, generated_ids)
|
||||||
@ -2763,3 +2766,110 @@ class LLMASR5(nn.Module):
|
|||||||
ibest_writer["text_tn"][key[0]] = response_clean
|
ibest_writer["text_tn"][key[0]] = response_clean
|
||||||
|
|
||||||
return results, meta_data
|
return results, meta_data
|
||||||
|
|
||||||
|
def audio_decode(
    self,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
    min_length=None,
    max_length: int = 30 * 25,
    infer_cfg_ratio=None,
    decoding_length=None,
):
    """Autoregressively sample codec token ids conditioned on ``text``.

    Each step asks ``self.audio_decoder.score`` for the next-token scores,
    optionally mixes in a second (classifier-free-guidance style) score, and
    samples one id per quantizer with ``self.ras_sampling``. Decoding stops
    after ``max_length`` steps or as soon as any quantizer emits the EOS id
    (``self.codebook_size + self.ad_sos_eos``).

    Args:
        text: Conditioning features, projected by ``self.audio_decoder_in_proj``.
            NOTE(review): only batch element 0 is decoded (``seq_input[0]``,
            ``prompt[0]``), so this presumably expects a batch of one - confirm.
        text_lengths: Lengths forwarded to ``self.prepare_audio_decoder_io``.
        min_length: If given, EOS is suppressed for steps ``i < min_length``.
        max_length: Maximum number of decoding steps.
        infer_cfg_ratio: If given, guidance weight ``w`` used as
            ``(1 + w) * pred - w * cfg_pred``.
        decoding_length: Only switches the return type; when not ``None`` the
            EOS flag is returned as well.

    Returns:
        An int64 tensor built from ``[out_tokens]`` (one row of ``predict_nq``
        ids per generated step, EOS step removed). When ``decoding_length`` is
        not ``None`` a tuple ``(tokens, hit_eos)`` is returned instead.
    """
    # 1. encode text
    text = self.audio_decoder_in_proj(text)
    device = text.device

    # Single definition of the EOS id. The original in-loop check used
    # ``self.sos_eos`` while every other place used ``self.ad_sos_eos``.
    eos_id = self.codebook_size + self.ad_sos_eos

    sos_eos_emb = self.audio_decoder_embedding(
        torch.tensor([[self.ad_sos_eos]], dtype=torch.int64, device=device)
    )
    task_id_emb = self.audio_decoder_embedding(
        torch.tensor([[self.ad_task_id]], dtype=torch.int64, device=device)
    )
    prompt = torch.cat([sos_eos_emb, text, task_id_emb], dim=1)

    out_tokens = []
    state, cfg_state = None, None
    for i in range(max_length):
        if out_tokens:
            # Stop before building any tensors if a quantizer produced EOS.
            if any(tok == eos_id for tok in out_tokens[-1]):
                break
            codec_prompt = torch.tensor([out_tokens], dtype=torch.int64, device=device)
            codec_lengths = torch.tensor([len(out_tokens)], dtype=torch.int64, device=device)
        else:
            # First step: nothing has been generated yet.
            codec_prompt, codec_lengths = None, None

        seq_input, _ = self.prepare_audio_decoder_io(
            text, text_lengths, codec_prompt, codec_lengths, need_targets=False
        )

        # use state for speedup
        pred, (state, _) = self.audio_decoder.score(seq_input[0], state, prompt[0])
        if infer_cfg_ratio is not None:
            # Second pass that only sees the last prompt position onwards.
            cond_len = prompt[0].shape[0]
            cfg_pred, (cfg_state, _) = self.audio_decoder.score(
                seq_input[0][cond_len - 1:],
                cfg_state,
                prompt[0][cond_len - 1:],
            )
            pred = (1 + infer_cfg_ratio) * pred - infer_cfg_ratio * cfg_pred

        # sampling all `nq` token ids
        pred = pred.reshape(self.predict_nq, -1)
        # normalize scores
        pred = torch.log_softmax(pred, dim=-1)
        if min_length is not None and i < min_length:
            # Forbid EOS until the minimum length has been reached.
            pred[:, eos_id] = float(np.finfo(np.float32).min)

        top_ids = [
            self.ras_sampling(pred[k], out_tokens)[0].item()
            for k in range(self.predict_nq)
        ]
        out_tokens.append(top_ids)

    # remove eos token (guarded so that ``max_length <= 0`` does not crash)
    hit_eos = bool(out_tokens) and any(tok == eos_id for tok in out_tokens[-1])
    if hit_eos:
        out_tokens = out_tokens[:-1]

    tokens = torch.tensor([out_tokens], dtype=torch.int64, device=device)
    if decoding_length is None:
        return tokens
    return tokens, hit_eos
|
||||||
|
|
||||||
|
# Repetition Aware Sampling in VALL-E 2
|
||||||
|
def ras_sampling(
    self,
    weighted_scores, decoded_tokens, *,
    top_p=0.8, top_k=25, win_size=10, tau_r=0.1
):
    """Sample one token id, falling back to plain sampling on repetition.

    A candidate is first drawn with ``self.nucleus_sampling``. If that
    candidate already occurs at least ``win_size * tau_r`` times among the
    last ``win_size`` entries of ``decoded_tokens``, it is discarded and a
    new id is drawn with ``self.random_sampling`` instead.

    Args:
        weighted_scores: Scores passed through to both samplers.
        decoded_tokens: Previously generated tokens (a Python sequence that
            ``torch.tensor`` can convert).
        top_p: Forwarded to ``nucleus_sampling``.
        top_k: Forwarded to ``nucleus_sampling``.
        win_size: Number of most recent entries inspected for repetition.
        tau_r: Repetition ratio threshold relative to ``win_size``.

    Returns:
        The tensor returned by whichever sampler was used last.
    """
    candidate = self.nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k)

    # Count how often the candidate shows up in the recent history window.
    recent = torch.tensor(decoded_tokens[-win_size:]).to(candidate)
    repeats = (recent == candidate).sum().item()

    if repeats >= win_size * tau_r:
        return self.random_sampling(weighted_scores)
    return candidate
|
||||||
|
|
||||||
|
def nucleus_sampling(self, weighted_scores, top_p=0.8, top_k=25):
    """Draw one token id from the top-p / top-k head of the distribution.

    ``weighted_scores`` is softmax-normalised along dim 0 and sorted in
    descending order. Entries are accepted from the top while the probability
    mass accepted so far is below ``top_p`` and fewer than ``top_k`` entries
    have been accepted. One id is then drawn from the accepted entries with
    ``multinomial``.

    Args:
        weighted_scores: 1-D score tensor (softmax is taken over dim 0).
        top_p: Cumulative-probability cutoff.
        top_k: Maximum number of candidates kept.

    Returns:
        A tensor holding the sampled index into ``weighted_scores``.
    """
    sorted_value, sorted_idx = weighted_scores.softmax(dim=0).sort(descending=True, stable=True)

    # Count the accepted prefix. Only the first ``top_k`` entries can ever be
    # accepted, so the loop never walks the full vocabulary.
    keep = 0
    cum_prob = 0.0
    for value in sorted_value[:max(top_k, 0)]:
        if not cum_prob < top_p:
            break
        cum_prob += value
        keep += 1

    # Slice instead of rebuilding tensors from Python lists of 0-dim tensors.
    prob = sorted_value[:keep]
    indices = sorted_idx[:keep]
    sampling_ids = prob.multinomial(1, replacement=True)
    return indices[sampling_ids]
|
||||||
|
|
||||||
|
def random_sampling(self, weighted_scores):
    """Draw one index from the full softmax distribution of the scores.

    Args:
        weighted_scores: Score tensor; softmax is taken over dim 0.

    Returns:
        A tensor with the single sampled index.
    """
    probabilities = torch.softmax(weighted_scores, dim=0)
    return torch.multinomial(probabilities, 1, replacement=True)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user