游雁 2024-06-08 19:45:15 +08:00
parent 3d5e19792c
commit 2191795f74

@@ -692,6 +692,7 @@ class LLMASR2(nn.Module):
batch_idx, :min_len, :
]
label = contents["assistant"][0]
if not kwargs.get("tearchforing", False):
generated_ids = self.llm.generate(
@@ -704,7 +705,7 @@ class LLMASR2(nn.Module):
response = tokenizer.batch_decode(
generated_ids, skip_special_tokens=kwargs.get("skip_special_tokens", True)
)[0]
label = contents["assistant"][0]
loss = None
else:
@@ -715,13 +716,13 @@ class LLMASR2(nn.Module):
inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids
)
-preds = torch.argmax(model_outputs.logits, -1)[:, source_ids.shape[1]]
+preds = torch.argmax(model_outputs.logits, -1)[:, source_ids.shape[1] :]
response = tokenizer.batch_decode(
preds,
add_special_tokens=False,
skip_special_tokens=kwargs.get("skip_special_tokens", True),
)[0]
-loss = model_outputs.loss
+loss = model_outputs.loss.item()
ibest_writer = None
if kwargs.get("output_dir") is not None: