游雁 2024-07-02 14:10:57 +08:00
parent b6aad84db6
commit 4784baf2af


@@ -1408,7 +1408,7 @@ class LLMASR4(nn.Module):
         return results, meta_data
 
 
-@tables.register("model_classes", "LLMASR5")
+# @tables.register("model_classes", "LLMASR5")
 class LLMASR5(nn.Module):
     """ """
@@ -2011,41 +2011,19 @@ class LLMASR5(nn.Module):
     """ """
 
     def __init__(
-            self,
-            specaug: str = None,
-            specaug_conf: dict = None,
-            normalize: str = None,
-            normalize_conf: dict = None,
-            audio_encoder: str = None,
-            audio_encoder_conf: dict = None,
-            audio_adaptor: str = None,
-            audio_adaptor_conf: dict = None,
-            decoder: str = None,
-            decoder_conf: dict = None,
-            ctc: str = None,
-            ctc_conf: dict = None,
-            ctc_weight: float = 0.5,
-            llm: str = None,
-            llm_conf: dict = None,
-            input_size: int = 80,
-            vocab_size: int = -1,
-            ignore_id: int = -1,
-            blank_id: int = 0,
-            sos: int = 1,
-            eos: int = 2,
-            lsm_weight: float = 0.0,
-            length_normalized_loss: bool = False,
-            report_cer: bool = True,
-            report_wer: bool = True,
-            sym_space: str = "<space>",
-            sym_blank: str = "<blank>",
-            # extract_feats_in_collect_stats: bool = True,
-            share_embedding: bool = False,
-            # preencoder: Optional[AbsPreEncoder] = None,
-            # postencoder: Optional[AbsPostEncoder] = None,
-            audio_decoder: str = None,
-            audio_decoder_conf: dict = None,
-            **kwargs,
+        self,
+        audio_encoder: str = None,
+        audio_encoder_conf: dict = None,
+        audio_adaptor: str = None,
+        audio_adaptor_conf: dict = None,
+        llm: str = None,
+        llm_conf: dict = None,
+        input_size: int = 80,
+        lsm_weight: float = 0.0,
+        length_normalized_loss: bool = False,
+        audio_decoder: str = None,
+        audio_decoder_conf: dict = None,
+        **kwargs,
     ):
         super().__init__()
@@ -2082,7 +2060,7 @@ class LLMASR5(nn.Module):
                 idx = re.search(r"\.\d+\.", name)
                 if idx is not None:
                     beg, end = idx.regs[0]
-                    layer_id = int(name[beg + 1: end - 1])
+                    layer_id = int(name[beg + 1 : end - 1])
                     if layer_id < freeze_layer_num:
                         param.requires_grad = False
                 elif "ln_post." not in name:
@@ -2134,6 +2112,8 @@ class LLMASR5(nn.Module):
         self.length_normalized_loss = length_normalized_loss
         self.beam_search = None
         self.eos = kwargs.get("eos", 151645)
 
+        # audio decoder related
+        self.audio_decoder = self.build_audio_decoder(name=audio_decoder, conf=audio_decoder_conf)
         self.audio_decoder_in_proj = torch.nn.Linear(llm_dim, self.audio_decoder.embed_unit)
@@ -2148,16 +2128,12 @@ class LLMASR5(nn.Module):
 
     def build_audio_decoder(self, name, conf):
         if name == "transformer":
            from funasr.models.llm_asr.transformer_lm import TransformerEmbedLM
 
            if "text_vocab_size" in conf:
-                lm_model = TransformerEmbedLM(
-                    vocab_size=self.lm_out_voc_size,
-                    **conf
-                )
+                lm_model = TransformerEmbedLM(vocab_size=self.lm_out_voc_size, **conf)
             else:
                 lm_model = TransformerEmbedLM(
-                    vocab_size=self.lm_out_voc_size,
-                    text_vocab_size=self.lm_out_voc_size,
-                    **conf
+                    vocab_size=self.lm_out_voc_size, text_vocab_size=self.lm_out_voc_size, **conf
                 )
         else:
             raise TypeError(f"Unknown codec decoder type {name}")
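
Aside: the two branches above differ only in whether `text_vocab_size` is forwarded explicitly. A behavior-preserving one-branch sketch of the same dispatch (not the file's code; `conf` is copied so the caller's dict is not mutated):

def build_audio_decoder(self, name, conf):
    if name != "transformer":
        raise TypeError(f"Unknown codec decoder type {name}")
    from funasr.models.llm_asr.transformer_lm import TransformerEmbedLM

    conf = dict(conf)  # defensive copy: setdefault below would otherwise mutate the caller's config
    conf.setdefault("text_vocab_size", self.lm_out_voc_size)
    return TransformerEmbedLM(vocab_size=self.lm_out_voc_size, **conf)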
@@ -2175,30 +2151,35 @@ class LLMASR5(nn.Module):
         return self.codec_embedder(codec * mask).sum(dim=-2) * mask
 
     def prepare_audio_decoder_io(
-            self,
-            text: torch.Tensor,
-            text_lengths: torch.Tensor,
-            codec: Optional[torch.Tensor] = None,
-            codec_lengths: Optional[torch.Tensor] = None,
-            need_targets: bool = True,
+        self,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
+        codec: Optional[torch.Tensor] = None,
+        codec_lengths: Optional[torch.Tensor] = None,
+        need_targets: bool = True,
     ):
         """build inputs and targets for language model
-            Normally, this function is called in batchify_nll.
-            Args:
-                text: (Batch, Length, Dim)
-                text_lengths: (Batch,)
-                codec: (Batch, Length)
-                codec_lengths: (Batch,)
-                need_targets: bool, whether provide targets
-        """
+        Normally, this function is called in batchify_nll.
+        Args:
+            text: (Batch, Length, Dim)
+            text_lengths: (Batch,)
+            codec: (Batch, Length)
+            codec_lengths: (Batch,)
+            need_targets: bool, whether provide targets
+        """
         if need_targets:
-            assert codec is not None and codec_lengths is not None, \
-                "need_target=True, but codec or codec_length is None"
+            assert (
+                codec is not None and codec_lengths is not None
+            ), "need_target=True, but codec or codec_length is None"
 
-        sos_eos_emb = self.audio_decoder_embedding(torch.tensor([self.ad_sos_eos], dtype=torch.int64, device=text.device))
-        task_id_emb = self.audio_decoder_embedding(torch.tensor([self.ad_task_id], dtype=torch.int64, device=text.device))
+        sos_eos_emb = self.audio_decoder_embedding(
+            torch.tensor([self.ad_sos_eos], dtype=torch.int64, device=text.device)
+        )
+        task_id_emb = self.audio_decoder_embedding(
+            torch.tensor([self.ad_task_id], dtype=torch.int64, device=text.device)
+        )
         codec_emb = None
         if codec is not None and codec_lengths is not None:
             codec_emb = self.calc_dense_vector(codec, codec_lengths)
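
For orientation, each per-utterance decoder input assembled in the next hunk is the concatenation [sos/eos embedding, projected text, task-id embedding, dense codec vectors], which is why llm_lengths = text_lengths + 2 before codec frames are counted. A shape-only sketch with hypothetical sizes:

import torch

D = 1024                           # assumed decoder embedding width
sos_eos_emb = torch.zeros(1, D)    # stands in for audio_decoder_embedding(ad_sos_eos)
task_id_emb = torch.zeros(1, D)    # stands in for audio_decoder_embedding(ad_task_id)
text_i = torch.zeros(7, D)         # projected text hidden states, text_len = 7
codec_emb_i = torch.zeros(30, D)   # dense codec vectors, codec_len = 30

one_input = torch.cat([sos_eos_emb, text_i, task_id_emb, codec_emb_i], dim=0)
assert one_input.shape == (1 + 7 + 1 + 30, D)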
@@ -2206,7 +2187,7 @@ class LLMASR5(nn.Module):
         for i, text_len in enumerate(text_lengths):
             one_input = [sos_eos_emb, text[i, :text_len], task_id_emb]
             if codec_emb is not None:
-                one_input.append(codec_emb[i, :codec_lengths[i]])
+                one_input.append(codec_emb[i, : codec_lengths[i]])
             inputs_list.append(torch.cat(one_input, dim=0))
         llm_inputs = pad_list(inputs_list, 0.0)
         llm_lengths = text_lengths + 2
@@ -2217,7 +2198,9 @@ class LLMASR5(nn.Module):
             return llm_inputs, llm_lengths
 
         bb, tt = text.shape[0], codec_lengths.max() + 1
-        llm_targets = -1 * torch.ones([bb, tt, self.predict_nq], dtype=torch.int64, device=text.device)
+        llm_targets = -1 * torch.ones(
+            [bb, tt, self.predict_nq], dtype=torch.int64, device=text.device
+        )
         for i, codec_len in enumerate(codec_lengths):
             llm_targets[i, :codec_len] = codec[i, :codec_len]
             llm_targets[i, codec_len] = self.codebook_size + self.sos_eos
@@ -2242,36 +2225,33 @@ class LLMASR5(nn.Module):
         """
         batch_size = text.size(0)
         # For data parallel
-        text = text[:, :text_lengths.max()]
-        codec = codec[:, :codec_lengths.max()]
+        text = text[:, : text_lengths.max()]
+        codec = codec[:, : codec_lengths.max()]
         text = self.audio_decoder_in_proj(text)
 
         # build inputs and targets for language model
         with autocast(False):
             (sequence, target), (x_lengths, y_lengths) = self.prepare_audio_decoder_io(
-                text, text_lengths,
-                codec, codec_lengths,
-                need_targets=True
+                text, text_lengths, codec, codec_lengths, need_targets=True
             )
 
         # 2a. Forward Language model
         # x: (Batch, Length) -> y: (Batch, Length, NVocab)
-        sequence = sequence[:, :x_lengths.max()]
-        target = target[:, :y_lengths.max()]
-        y, _ = self.audio_decoder(sequence, x_lengths, text_lengths+1)
+        sequence = sequence[:, : x_lengths.max()]
+        target = target[:, : y_lengths.max()]
+        y, _ = self.audio_decoder(sequence, x_lengths, text_lengths + 1)
         bb, tt = y.shape[0], y.shape[1]
         y = y.reshape(bb, tt, self.predict_nq, -1)
 
         # 2b. Extract real logits
         logits_list = []
         for i, (text_len, codec_len) in enumerate(zip(text_lengths, codec_lengths)):
-            logits_list.append(y[i, text_len + 1:text_len + 2 + codec_len])
+            logits_list.append(y[i, text_len + 1 : text_len + 2 + codec_len])
         logits = pad_list(logits_list, 0.0)
 
         # 3. Calc negative log likelihood
         tt = logits.shape[1]
         nll = self.criterion_ce(
-            logits.reshape(bb, tt * self.predict_nq, -1),
-            target.reshape(bb, tt * self.predict_nq)
+            logits.reshape(bb, tt * self.predict_nq, -1), target.reshape(bb, tt * self.predict_nq)
         )
         nll = nll.sum(-1)
         # nll: (BxL,) -> (BxL,)
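
The criterion call above flattens the predict_nq codebook streams into one long sequence before computing cross entropy, then sums the per-codebook terms. A runnable shape check using plain cross-entropy as a stand-in for criterion_ce (the sizes, and the absence of label smoothing, are assumptions):

import torch
import torch.nn.functional as F

bb, tt, nq, vocab = 2, 31, 4, 4096             # hypothetical sizes
logits = torch.randn(bb, tt, nq, vocab)
target = torch.randint(0, vocab, (bb, tt, nq))

nll = F.cross_entropy(
    logits.reshape(bb, tt * nq, vocab).transpose(1, 2),  # (bb, vocab, tt*nq)
    target.reshape(bb, tt * nq),
    reduction="none",
)                                               # (bb, tt*nq): one NLL per codebook step
per_step = nll.reshape(bb, tt, nq)              # one loss per codebook per target step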
@@ -2279,18 +2259,18 @@ class LLMASR5(nn.Module):
         # nll: (BxL,) -> (B, L)
         nll = nll.reshape(batch_size, -1).reshape(batch_size, tt, self.predict_nq)
-        return nll, logits, target, codec_lengths+1
+        return nll, logits, target, codec_lengths + 1
 
     def forward(
-            self,
-            speech: torch.Tensor,
-            speech_lengths: torch.Tensor,
-            input_ids: torch.Tensor,
-            attention_mask: torch.Tensor,
-            labels_ids: torch.Tensor,
-            fbank_beg: torch.Tensor,
-            fbank_mask: torch.Tensor,
-            **kwargs,
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        input_ids: torch.Tensor,
+        attention_mask: torch.Tensor,
+        labels_ids: torch.Tensor,
+        fbank_beg: torch.Tensor,
+        fbank_mask: torch.Tensor,
+        **kwargs,
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
         """Encoder + Decoder + Calc loss
         Args:
@@ -2334,7 +2314,7 @@ class LLMASR5(nn.Module):
                 try:
                     inputs_embeds[
-                        batch_idx, fbank_beg_idx: fbank_beg_idx + speech_token_len, :
+                        batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, :
                     ] = speech_token
                 except Exception as e:
                     #
@@ -2347,13 +2327,13 @@ class LLMASR5(nn.Module):
                     speech_token_len = encoder_out_lens[speech_idx].item()
                     speech_token = encoder_out[speech_idx, :speech_token_len, :]
                     inputs_embeds[
-                        batch_idx, fbank_beg_idx: fbank_beg_idx + speech_token_len, :
+                        batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, :
                     ] = speech_token
                     speech_idx += 1
 
         with torch.cuda.amp.autocast(
-                enabled=True if self.llm_dtype != "fp32" else False, dtype=dtype_map[self.llm_dtype]
+            enabled=True if self.llm_dtype != "fp32" else False, dtype=dtype_map[self.llm_dtype]
         ):
             labels_ids[labels_ids == -1] = -100
             attention_mask[attention_mask < 0] = 0
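
`dtype_map` itself is outside this diff; a plausible definition, consistent with the "fp32"/"bf16" strings used in this class (an assumption, not the file's verbatim code). Note that `enabled=True if self.llm_dtype != "fp32" else False` could equally be written `enabled=self.llm_dtype != "fp32"`:

import torch

# Assumed mapping from the dtype strings used in this class to torch dtypes.
dtype_map = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}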
@@ -2364,6 +2344,47 @@ class LLMASR5(nn.Module):
             )
             loss = model_outputs.loss
 
+            codec = kwargs.get("codec")
+            codec_len = kwargs.get("codec_len")
+            if len(codec_len.size()) > 1:
+                codec_len = codec_len[:, 0]
+            hidden_states = model_outputs.hidden_states[-1].float()
+
+            target_ids = []
+            target_ids_len = []
+            hidden_states_select = []
+            for batch_idx in range(labels_ids.shape[0]):
+                beg_i = 0
+                end_i = 0
+                for token_idx in range(labels_ids.shape[1]):
+                    token_int = labels_ids[batch_idx, token_idx].item()
+                    if token_int == self.eos:
+                        target_ids_i = labels_ids[batch_idx, beg_i:end_i]
+                        target_ids_len_i = end_i - beg_i
+                        target_ids_len.append(target_ids_len_i)
+                        target_ids.append(target_ids_i)
+                        hidden_states_i = hidden_states[batch_idx, beg_i - 1 : end_i - 1, :]
+                        hidden_states_select.append(hidden_states_i)
+                        beg_i = end_i
+                        continue
+                    end_i += 1
+                    if token_int <= 0:
+                        beg_i += 1
+
+            target_ids = torch.nn.utils.rnn.pad_sequence(
+                target_ids, batch_first=True, padding_value=-100
+            )
+            hidden_states_select = torch.nn.utils.rnn.pad_sequence(
+                hidden_states_select, batch_first=True, padding_value=0.0
+            )
+            target_ids_len = torch.tensor(target_ids_len, dtype=torch.int32, device=input_ids.device)
+            target_ids = target_ids.to(device=input_ids.device)
+            hidden_states_select = hidden_states_select.to(device=input_ids.device)
+
+            loss, logits, target, codec_lengths = self.nll(
+                hidden_states_select, target_ids_len, codec, codec_len
+            )
+
         stats = {}
         with torch.no_grad():
             preds = torch.argmax(model_outputs.logits, -1)
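
The block added in this hunk reuses the LLM's last hidden layer as text conditioning for the audio decoder: every eos-terminated span of labels_ids becomes a target sequence, and the hidden states one position to the left (the states that predicted those tokens) are gathered to match. A toy trace of the span bookkeeping, with hypothetical token ids and a single span:

import torch

eos = 9                                           # stands in for self.eos (151645)
labels = torch.tensor([-100, -100, 5, 6, 7, 9])   # padding, then one eos-terminated span
beg = end = 0
for t in range(labels.shape[0]):
    tok = labels[t].item()
    if tok == eos:
        print(labels[beg:end])   # tensor([5, 6, 7]) -> audio-decoder targets
        # hidden_states[beg - 1 : end - 1] pairs each target token with the LLM
        # state that predicted it (states sit one step left of the labels).
        beg = end
        continue
    end += 1
    if tok <= 0:                 # padding advances the span start
        beg += 1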
@@ -2487,10 +2508,10 @@ class LLMASR5(nn.Module):
         time3 = time.perf_counter()
         meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
         meta_data["batch_data_time"] = (
-                speech_lengths.sum().item()
-                * frontend.frame_shift
-                * frontend.lfr_n
-                / 1000
+            speech_lengths.sum().item()
+            * frontend.frame_shift
+            * frontend.lfr_n
+            / 1000
         )
 
         if kwargs.get("permute", True):
@@ -2558,13 +2579,13 @@ class LLMASR5(nn.Module):
         return output
 
     def inference_prepare(
-            self,
-            data_in,
-            data_lengths=None,
-            key: list = None,
-            tokenizer=None,
-            frontend=None,
-            **kwargs,
+        self,
+        data_in,
+        data_lengths=None,
+        key: list = None,
+        tokenizer=None,
+        frontend=None,
+        **kwargs,
     ):
         meta_data = {}
@@ -2619,7 +2640,7 @@ class LLMASR5(nn.Module):
                 try:
                     inputs_embeds[
-                        batch_idx, fbank_beg_idx: fbank_beg_idx + speech_token_len, :
+                        batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, :
                     ] = speech_token
                 except Exception as e:
                     #
@@ -2632,20 +2653,20 @@ class LLMASR5(nn.Module):
                     speech_token_len = encoder_out_lens[speech_idx].item()
                     speech_token = encoder_out[speech_idx, :speech_token_len, :]
                     inputs_embeds[
-                        batch_idx, fbank_beg_idx: fbank_beg_idx + speech_token_len, :
+                        batch_idx, fbank_beg_idx : fbank_beg_idx + speech_token_len, :
                     ] = speech_token
                     speech_idx += 1
 
         return inputs_embeds, contents, batch, source_ids, meta_data
 
     def inference(
-            self,
-            data_in,
-            data_lengths=None,
-            key: list = None,
-            tokenizer=None,
-            frontend=None,
-            **kwargs,
+        self,
+        data_in,
+        data_lengths=None,
+        key: list = None,
+        tokenizer=None,
+        frontend=None,
+        **kwargs,
     ):
         inputs_embeds, contents, batch, source_ids, meta_data = self.inference_prepare(
@@ -2658,7 +2679,7 @@ class LLMASR5(nn.Module):
         llm_dtype = "bf16" if kwargs.get("bf16", False) else llm_dtype
         with torch.cuda.amp.autocast(
-                enabled=True if llm_dtype != "fp32" else False, dtype=dtype_map[llm_dtype]
+            enabled=True if llm_dtype != "fp32" else False, dtype=dtype_map[llm_dtype]
         ):
             label = contents["assistant"][-1]
             self.llm = self.llm.to(dtype_map[llm_dtype])
@@ -2688,7 +2709,7 @@ class LLMASR5(nn.Module):
                 inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids
             )
-            preds = torch.argmax(model_outputs.logits, -1)[:, source_ids.shape[1]:]
+            preds = torch.argmax(model_outputs.logits, -1)[:, source_ids.shape[1] :]
             response = tokenizer.batch_decode(
                 preds,
                 add_special_tokens=False,