This commit is contained in:
游雁 2024-06-08 19:45:15 +08:00
parent 3d5e19792c
commit 2191795f74

View File

@ -692,6 +692,7 @@ class LLMASR2(nn.Module):
batch_idx, :min_len, : batch_idx, :min_len, :
] ]
label = contents["assistant"][0]
if not kwargs.get("tearchforing", False): if not kwargs.get("tearchforing", False):
generated_ids = self.llm.generate( generated_ids = self.llm.generate(
@ -704,7 +705,7 @@ class LLMASR2(nn.Module):
response = tokenizer.batch_decode( response = tokenizer.batch_decode(
generated_ids, skip_special_tokens=kwargs.get("skip_special_tokens", True) generated_ids, skip_special_tokens=kwargs.get("skip_special_tokens", True)
)[0] )[0]
label = contents["assistant"][0]
loss = None loss = None
else: else:
@ -715,13 +716,13 @@ class LLMASR2(nn.Module):
inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids
) )
preds = torch.argmax(model_outputs.logits, -1)[:, source_ids.shape[1]] preds = torch.argmax(model_outputs.logits, -1)[:, source_ids.shape[1] :]
response = tokenizer.batch_decode( response = tokenizer.batch_decode(
preds, preds,
add_special_tokens=False, add_special_tokens=False,
skip_special_tokens=kwargs.get("skip_special_tokens", True), skip_special_tokens=kwargs.get("skip_special_tokens", True),
)[0] )[0]
loss = model_outputs.loss loss = model_outputs.loss.item()
ibest_writer = None ibest_writer = None
if kwargs.get("output_dir") is not None: if kwargs.get("output_dir") is not None: