mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
auto frontend
This commit is contained in:
parent
162efb747f
commit
526c810bd7
@ -155,6 +155,9 @@ class OpenAIDataset(torch.utils.data.Dataset):
|
||||
fbank_beg.append(fbank_beg_i)
|
||||
|
||||
if len(input_ids) > self.max_token_length:
|
||||
logging.info(
|
||||
f"input_ids > max_token_length: {len(input_ids)}>{self.max_token_length}, {item}"
|
||||
)
|
||||
badcase_flag = True
|
||||
if badcase_flag:
|
||||
continue
|
||||
|
||||
@ -485,13 +485,24 @@ class LLMASR2(nn.Module):
|
||||
# _, l, _ = encoder_out.shape
|
||||
for batch_idx in range(batch_size):
|
||||
|
||||
l = fbank_fake_lens[batch_idx].item()
|
||||
fbank_fake_len = fbank_fake_lens[batch_idx].item()
|
||||
fbank_beg_idx = fbank_beg[batch_idx, 0].item()
|
||||
min_len = min(l, inputs_embeds.shape[1] - fbank_beg_idx)
|
||||
min_len = min(fbank_fake_len, inputs_embeds.shape[1] - fbank_beg_idx)
|
||||
try:
|
||||
inputs_embeds[batch_idx, fbank_beg_idx : fbank_beg_idx + min_len, :] = encoder_out[
|
||||
batch_idx, :min_len, :
|
||||
]
|
||||
except Exception as e:
|
||||
logging.error(f"{str(e)}, {traceback.format_exc()}")
|
||||
logging.info(
|
||||
f"batch_idx: {batch_idx}, inputs_embeds: {inputs_embeds.shape}, fbank_beg_idx: {fbank_beg_idx}, min_len: {min_len}, fbank_fake_len: {fbank_fake_len}"
|
||||
)
|
||||
fbank_fake_len = encoder_out_lens[batch_idx].item()
|
||||
min_len = min(fbank_fake_len, inputs_embeds.shape[1] - fbank_beg_idx)
|
||||
inputs_embeds[batch_idx, fbank_beg_idx : fbank_beg_idx + min_len, :] = encoder_out[
|
||||
batch_idx, :min_len, :
|
||||
]
|
||||
|
||||
inputs_embeds[batch_idx, fbank_beg_idx : fbank_beg_idx + min_len, :] = encoder_out[
|
||||
batch_idx, :min_len, :
|
||||
]
|
||||
labels_ids[labels_ids == -1] = -100
|
||||
model_outputs = self.llm(
|
||||
inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids
|
||||
|
||||
Loading…
Reference in New Issue
Block a user