mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
update
This commit is contained in:
parent
b6aad84db6
commit
4784baf2af
@ -1408,7 +1408,7 @@ class LLMASR4(nn.Module):
|
||||
return results, meta_data
|
||||
|
||||
|
||||
@tables.register("model_classes", "LLMASR5")
|
||||
# @tables.register("model_classes", "LLMASR5")
|
||||
class LLMASR5(nn.Module):
|
||||
""" """
|
||||
|
||||
@ -2011,41 +2011,19 @@ class LLMASR5(nn.Module):
|
||||
""" """
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
specaug: str = None,
|
||||
specaug_conf: dict = None,
|
||||
normalize: str = None,
|
||||
normalize_conf: dict = None,
|
||||
audio_encoder: str = None,
|
||||
audio_encoder_conf: dict = None,
|
||||
audio_adaptor: str = None,
|
||||
audio_adaptor_conf: dict = None,
|
||||
decoder: str = None,
|
||||
decoder_conf: dict = None,
|
||||
ctc: str = None,
|
||||
ctc_conf: dict = None,
|
||||
ctc_weight: float = 0.5,
|
||||
llm: str = None,
|
||||
llm_conf: dict = None,
|
||||
input_size: int = 80,
|
||||
vocab_size: int = -1,
|
||||
ignore_id: int = -1,
|
||||
blank_id: int = 0,
|
||||
sos: int = 1,
|
||||
eos: int = 2,
|
||||
lsm_weight: float = 0.0,
|
||||
length_normalized_loss: bool = False,
|
||||
report_cer: bool = True,
|
||||
report_wer: bool = True,
|
||||
sym_space: str = "<space>",
|
||||
sym_blank: str = "<blank>",
|
||||
# extract_feats_in_collect_stats: bool = True,
|
||||
share_embedding: bool = False,
|
||||
# preencoder: Optional[AbsPreEncoder] = None,
|
||||
# postencoder: Optional[AbsPostEncoder] = None,
|
||||
audio_decoder: str = None,
|
||||
audio_decoder_conf: dict = None,
|
||||
**kwargs,
|
||||
self,
|
||||
audio_encoder: str = None,
|
||||
audio_encoder_conf: dict = None,
|
||||
audio_adaptor: str = None,
|
||||
audio_adaptor_conf: dict = None,
|
||||
llm: str = None,
|
||||
llm_conf: dict = None,
|
||||
input_size: int = 80,
|
||||
lsm_weight: float = 0.0,
|
||||
length_normalized_loss: bool = False,
|
||||
audio_decoder: str = None,
|
||||
audio_decoder_conf: dict = None,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
@ -2082,7 +2060,7 @@ class LLMASR5(nn.Module):
|
||||
idx = re.search(r"\.\d+\.", name)
|
||||
if idx is not None:
|
||||
beg, end = idx.regs[0]
|
||||
layer_id = int(name[beg + 1: end - 1])
|
||||
layer_id = int(name[beg + 1 : end - 1])
|
||||
if layer_id < freeze_layer_num:
|
||||
param.requires_grad = False
|
||||
elif "ln_post." not in name:
|
||||
@ -2134,6 +2112,8 @@ class LLMASR5(nn.Module):
|
||||
self.length_normalized_loss = length_normalized_loss
|
||||
self.beam_search = None
|
||||
|
||||
self.eos = kwargs.get("eos", 151645)
|
||||
|
||||
# audio decoder related
|
||||
self.audio_decoder = self.build_audio_decoder(name=audio_decoder, conf=audio_decoder_conf)
|
||||
self.audio_decoder_in_proj = torch.nn.Linear(llm_dim, self.audio_decoder.embed_unit)
|
||||
@ -2148,16 +2128,12 @@ class LLMASR5(nn.Module):
|
||||
def build_audio_decoder(self, name, conf):
|
||||
if name == "transformer":
|
||||
from funasr.models.llm_asr.transformer_lm import TransformerEmbedLM
|
||||
|
||||
if "text_vocab_size" in conf:
|
||||
lm_model = TransformerEmbedLM(
|
||||
vocab_size=self.lm_out_voc_size,
|
||||
**conf
|
||||
)
|
||||
lm_model = TransformerEmbedLM(vocab_size=self.lm_out_voc_size, **conf)
|
||||
else:
|
||||
lm_model = TransformerEmbedLM(
|
||||
vocab_size=self.lm_out_voc_size,
|
||||
text_vocab_size=self.lm_out_voc_size,
|
||||
**conf
|
||||
vocab_size=self.lm_out_voc_size, text_vocab_size=self.lm_out_voc_size, **conf
|
||||
)
|
||||
else:
|
||||
raise TypeError(f"Unknown codec decoder type {name}")
|
||||
@ -2175,30 +2151,35 @@ class LLMASR5(nn.Module):
|
||||
return self.codec_embedder(codec * mask).sum(dim=-2) * mask
|
||||
|
||||
def prepare_audio_decoder_io(
|
||||
self,
|
||||
text: torch.Tensor,
|
||||
text_lengths: torch.Tensor,
|
||||
codec: Optional[torch.Tensor] = None,
|
||||
codec_lengths: Optional[torch.Tensor] = None,
|
||||
need_targets: bool = True,
|
||||
self,
|
||||
text: torch.Tensor,
|
||||
text_lengths: torch.Tensor,
|
||||
codec: Optional[torch.Tensor] = None,
|
||||
codec_lengths: Optional[torch.Tensor] = None,
|
||||
need_targets: bool = True,
|
||||
):
|
||||
"""build inputs and targets for language model
|
||||
|
||||
Normally, this function is called in batchify_nll.
|
||||
Args:
|
||||
text: (Batch, Length, Dim)
|
||||
text_lengths: (Batch,)
|
||||
codec: (Batch, Length)
|
||||
codec_lengths: (Batch,)
|
||||
need_targets: bool, whether provide targets
|
||||
"""
|
||||
Normally, this function is called in batchify_nll.
|
||||
Args:
|
||||
text: (Batch, Length, Dim)
|
||||
text_lengths: (Batch,)
|
||||
codec: (Batch, Length)
|
||||
codec_lengths: (Batch,)
|
||||
need_targets: bool, whether provide targets
|
||||
"""
|
||||
|
||||
if need_targets:
|
||||
assert codec is not None and codec_lengths is not None, \
|
||||
"need_target=True, but codec or codec_length is None"
|
||||
assert (
|
||||
codec is not None and codec_lengths is not None
|
||||
), "need_target=True, but codec or codec_length is None"
|
||||
|
||||
sos_eos_emb = self.audio_decoder_embedding(torch.tensor([self.ad_sos_eos], dtype=torch.int64, device=text.device))
|
||||
task_id_emb = self.audio_decoder_embedding(torch.tensor([self.ad_task_id], dtype=torch.int64, device=text.device))
|
||||
sos_eos_emb = self.audio_decoder_embedding(
|
||||
torch.tensor([self.ad_sos_eos], dtype=torch.int64, device=text.device)
|
||||
)
|
||||
task_id_emb = self.audio_decoder_embedding(
|
||||
torch.tensor([self.ad_task_id], dtype=torch.int64, device=text.device)
|
||||
)
|
||||
codec_emb = None
|
||||
if codec is not None and codec_lengths is not None:
|
||||
codec_emb = self.calc_dense_vector(codec, codec_lengths)
|
||||
@ -2206,7 +2187,7 @@ class LLMASR5(nn.Module):
|
||||
for i, text_len in enumerate(text_lengths):
|
||||
one_input = [sos_eos_emb, text[i, :text_len], task_id_emb]
|
||||
if codec_emb is not None:
|
||||
one_input.append(codec_emb[i, :codec_lengths[i]])
|
||||
one_input.append(codec_emb[i, : codec_lengths[i]])
|
||||
inputs_list.append(torch.cat(one_input, dim=0))
|
||||
llm_inputs = pad_list(inputs_list, 0.0)
|
||||
llm_lengths = text_lengths + 2
|
||||
@ -2217,7 +2198,9 @@ class LLMASR5(nn.Module):
|
||||
return llm_inputs, llm_lengths
|
||||
|
||||
bb, tt = text.shape[0], codec_lengths.max() + 1
|
||||
llm_targets = -1 * torch.ones([bb, tt, self.predict_nq], dtype=torch.int64, device=text.device)
|
||||
llm_targets = -1 * torch.ones(
|
||||
[bb, tt, self.predict_nq], dtype=torch.int64, device=text.device
|
||||
)
|
||||
for i, codec_len in enumerate(codec_lengths):
|
||||
llm_targets[i, :codec_len] = codec[i, :codec_len]
|
||||
llm_targets[i, codec_len] = self.codebook_size + self.sos_eos
|
||||
@ -2242,36 +2225,33 @@ class LLMASR5(nn.Module):
|
||||
"""
|
||||
batch_size = text.size(0)
|
||||
# For data parallel
|
||||
text = text[:, :text_lengths.max()]
|
||||
codec = codec[:, :codec_lengths.max()]
|
||||
text = text[:, : text_lengths.max()]
|
||||
codec = codec[:, : codec_lengths.max()]
|
||||
text = self.audio_decoder_in_proj(text)
|
||||
|
||||
# build inputs and targets for language model
|
||||
with autocast(False):
|
||||
(sequence, target), (x_lengths, y_lengths) = self.prepare_audio_decoder_io(
|
||||
text, text_lengths,
|
||||
codec, codec_lengths,
|
||||
need_targets=True
|
||||
text, text_lengths, codec, codec_lengths, need_targets=True
|
||||
)
|
||||
|
||||
# 2a. Forward Language model
|
||||
# x: (Batch, Length) -> y: (Batch, Length, NVocab)
|
||||
sequence = sequence[:, :x_lengths.max()]
|
||||
target = target[:, :y_lengths.max()]
|
||||
y, _ = self.audio_decoder(sequence, x_lengths, text_lengths+1)
|
||||
sequence = sequence[:, : x_lengths.max()]
|
||||
target = target[:, : y_lengths.max()]
|
||||
y, _ = self.audio_decoder(sequence, x_lengths, text_lengths + 1)
|
||||
bb, tt = y.shape[0], y.shape[1]
|
||||
y = y.reshape(bb, tt, self.predict_nq, -1)
|
||||
# 2b. Extract real logits
|
||||
logits_list = []
|
||||
for i, (text_len, codec_len) in enumerate(zip(text_lengths, codec_lengths)):
|
||||
logits_list.append(y[i, text_len + 1:text_len + 2 + codec_len])
|
||||
logits_list.append(y[i, text_len + 1 : text_len + 2 + codec_len])
|
||||
logits = pad_list(logits_list, 0.0)
|
||||
|
||||
# 3. Calc negative log likelihood
|
||||
tt = logits.shape[1]
|
||||
nll = self.criterion_ce(
|
||||
logits.reshape(bb, tt * self.predict_nq, -1),
|
||||
target.reshape(bb, tt * self.predict_nq)
|
||||
logits.reshape(bb, tt * self.predict_nq, -1), target.reshape(bb, tt * self.predict_nq)
|
||||
)
|
||||
nll = nll.sum(-1)
|
||||
# nll: (BxL,) -> (BxL,)
|
||||
@ -2279,18 +2259,18 @@ class LLMASR5(nn.Module):
|
||||
# nll: (BxL,) -> (B, L)
|
||||
nll = nll.reshape(batch_size, -1).reshape(batch_size, tt, self.predict_nq)
|
||||
|
||||
return nll, logits, target, codec_lengths+1
|
||||
return nll, logits, target, codec_lengths + 1
|
||||
|
||||
def forward(
|
||||
self,
|
||||
speech: torch.Tensor,
|
||||
speech_lengths: torch.Tensor,
|
||||
input_ids: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
labels_ids: torch.Tensor,
|
||||
fbank_beg: torch.Tensor,
|
||||
fbank_mask: torch.Tensor,
|
||||
**kwargs,
|
||||
self,
|
||||
speech: torch.Tensor,
|
||||
speech_lengths: torch.Tensor,
|
||||
input_ids: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
labels_ids: torch.Tensor,
|
||||
fbank_beg: torch.Tensor,
|
||||
fbank_mask: torch.Tensor,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
|
||||
"""Encoder + Decoder + Calc loss
|
||||
Args:
|
||||
@ -2334,7 +2314,7 @@ class LLMASR5(nn.Module):
|
||||
|
||||
try:
|
||||
inputs_embeds[
|
||||
batch_idx, fbank_beg_idx: fbank_beg_idx + speech_token_len, :
|
||||
batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, :
|
||||
] = speech_token
|
||||
except Exception as e:
|
||||
#
|
||||
@ -2347,13 +2327,13 @@ class LLMASR5(nn.Module):
|
||||
speech_token_len = encoder_out_lens[speech_idx].item()
|
||||
speech_token = encoder_out[speech_idx, :speech_token_len, :]
|
||||
inputs_embeds[
|
||||
batch_idx, fbank_beg_idx: fbank_beg_idx + speech_token_len, :
|
||||
batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, :
|
||||
] = speech_token
|
||||
|
||||
speech_idx += 1
|
||||
|
||||
with torch.cuda.amp.autocast(
|
||||
enabled=True if self.llm_dtype != "fp32" else False, dtype=dtype_map[self.llm_dtype]
|
||||
enabled=True if self.llm_dtype != "fp32" else False, dtype=dtype_map[self.llm_dtype]
|
||||
):
|
||||
labels_ids[labels_ids == -1] = -100
|
||||
attention_mask[attention_mask < 0] = 0
|
||||
@ -2364,6 +2344,47 @@ class LLMASR5(nn.Module):
|
||||
)
|
||||
loss = model_outputs.loss
|
||||
|
||||
codec = kwargs.get("codec")
|
||||
codec_len = kwargs.get("codec_len")
|
||||
if len(codec_len.size()) > 1:
|
||||
codec_len = codec_len[:, 0]
|
||||
hidden_states = model_outputs.hidden_states[-1].float()
|
||||
|
||||
target_ids = []
|
||||
target_ids_len = []
|
||||
hidden_states_select = []
|
||||
for batch_idx in range(labels_ids.shape[0]):
|
||||
beg_i = 0
|
||||
end_i = 0
|
||||
for token_idx in range(labels_ids.shape[1]):
|
||||
token_int = labels_ids[batch_idx, token_idx].item()
|
||||
if token_int == self.eos:
|
||||
target_ids_i = labels_ids[batch_idx, beg_i:end_i]
|
||||
target_ids_len_i = end_i - beg_i
|
||||
target_ids_len.append(target_ids_len_i)
|
||||
target_ids.append(target_ids_i)
|
||||
hidden_states_i = hidden_states[batch_idx, beg_i - 1 : end_i - 1, :]
|
||||
hidden_states_select.append(hidden_states_i)
|
||||
beg_i = end_i
|
||||
continue
|
||||
|
||||
end_i += 1
|
||||
if token_int <= 0:
|
||||
beg_i += 1
|
||||
|
||||
target_ids = torch.nn.utils.rnn.pad_sequence(
|
||||
target_ids, batch_first=True, padding_value=-100
|
||||
)
|
||||
hidden_states_select = torch.nn.utils.rnn.pad_sequence(
|
||||
hidden_states_select, batch_first=True, padding_value=0.0
|
||||
)
|
||||
target_ids_len = torch.tensor(target_ids_len, dtype=torch.int32, device=input_ids.device)
|
||||
target_ids = target_ids.to(device=input_ids.device)
|
||||
hidden_states_select = hidden_states_select.to(device=input_ids.device)
|
||||
loss, logits, target, codec_lengths = self.nll(
|
||||
hidden_states_select, target_ids_len, codec, codec_len
|
||||
)
|
||||
|
||||
stats = {}
|
||||
with torch.no_grad():
|
||||
preds = torch.argmax(model_outputs.logits, -1)
|
||||
@ -2487,10 +2508,10 @@ class LLMASR5(nn.Module):
|
||||
time3 = time.perf_counter()
|
||||
meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
|
||||
meta_data["batch_data_time"] = (
|
||||
speech_lengths.sum().item()
|
||||
* frontend.frame_shift
|
||||
* frontend.lfr_n
|
||||
/ 1000
|
||||
speech_lengths.sum().item()
|
||||
* frontend.frame_shift
|
||||
* frontend.lfr_n
|
||||
/ 1000
|
||||
)
|
||||
|
||||
if kwargs.get("permute", True):
|
||||
@ -2558,13 +2579,13 @@ class LLMASR5(nn.Module):
|
||||
return output
|
||||
|
||||
def inference_prepare(
|
||||
self,
|
||||
data_in,
|
||||
data_lengths=None,
|
||||
key: list = None,
|
||||
tokenizer=None,
|
||||
frontend=None,
|
||||
**kwargs,
|
||||
self,
|
||||
data_in,
|
||||
data_lengths=None,
|
||||
key: list = None,
|
||||
tokenizer=None,
|
||||
frontend=None,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
meta_data = {}
|
||||
@ -2619,7 +2640,7 @@ class LLMASR5(nn.Module):
|
||||
|
||||
try:
|
||||
inputs_embeds[
|
||||
batch_idx, fbank_beg_idx: fbank_beg_idx + speech_token_len, :
|
||||
batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, :
|
||||
] = speech_token
|
||||
except Exception as e:
|
||||
#
|
||||
@ -2632,20 +2653,20 @@ class LLMASR5(nn.Module):
|
||||
speech_token_len = encoder_out_lens[speech_idx].item()
|
||||
speech_token = encoder_out[speech_idx, :speech_token_len, :]
|
||||
inputs_embeds[
|
||||
batch_idx, fbank_beg_idx: fbank_beg_idx + speech_token_len, :
|
||||
batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, :
|
||||
] = speech_token
|
||||
|
||||
speech_idx += 1
|
||||
return inputs_embeds, contents, batch, source_ids, meta_data
|
||||
|
||||
def inference(
|
||||
self,
|
||||
data_in,
|
||||
data_lengths=None,
|
||||
key: list = None,
|
||||
tokenizer=None,
|
||||
frontend=None,
|
||||
**kwargs,
|
||||
self,
|
||||
data_in,
|
||||
data_lengths=None,
|
||||
key: list = None,
|
||||
tokenizer=None,
|
||||
frontend=None,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
inputs_embeds, contents, batch, source_ids, meta_data = self.inference_prepare(
|
||||
@ -2658,7 +2679,7 @@ class LLMASR5(nn.Module):
|
||||
llm_dtype = "bf16" if kwargs.get("bf16", False) else llm_dtype
|
||||
|
||||
with torch.cuda.amp.autocast(
|
||||
enabled=True if llm_dtype != "fp32" else False, dtype=dtype_map[llm_dtype]
|
||||
enabled=True if llm_dtype != "fp32" else False, dtype=dtype_map[llm_dtype]
|
||||
):
|
||||
label = contents["assistant"][-1]
|
||||
self.llm = self.llm.to(dtype_map[llm_dtype])
|
||||
@ -2688,7 +2709,7 @@ class LLMASR5(nn.Module):
|
||||
inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids
|
||||
)
|
||||
|
||||
preds = torch.argmax(model_outputs.logits, -1)[:, source_ids.shape[1]:]
|
||||
preds = torch.argmax(model_outputs.logits, -1)[:, source_ids.shape[1] :]
|
||||
response = tokenizer.batch_decode(
|
||||
preds,
|
||||
add_special_tokens=False,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user