Mirror of https://github.com/modelscope/FunASR, synced 2025-09-15 14:48:36 +08:00
fix bug

commit 2191795f74 (parent 3d5e19792c)
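All three hunks below touch the inference path of class LLMASR2(nn.Module).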
@@ -692,6 +692,7 @@ class LLMASR2(nn.Module):
                 batch_idx, :min_len, :
             ]
 
+        label = contents["assistant"][0]
         if not kwargs.get("tearchforing", False):
 
             generated_ids = self.llm.generate(
@@ -704,7 +705,7 @@ class LLMASR2(nn.Module):
             response = tokenizer.batch_decode(
                 generated_ids, skip_special_tokens=kwargs.get("skip_special_tokens", True)
             )[0]
-            label = contents["assistant"][0]
+
             loss = None
         else:
 
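Together, the first two hunks hoist label = contents["assistant"][0] out of the generate branch: the assignment now happens once before the if, so label is also defined when the teacher-forcing ("tearchforing") path below runs.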
@@ -715,13 +716,13 @@ class LLMASR2(nn.Module):
                 inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids
             )
 
-            preds = torch.argmax(model_outputs.logits, -1)[:, source_ids.shape[1]]
+            preds = torch.argmax(model_outputs.logits, -1)[:, source_ids.shape[1] :]
             response = tokenizer.batch_decode(
                 preds,
                 add_special_tokens=False,
                 skip_special_tokens=kwargs.get("skip_special_tokens", True),
             )[0]
-            loss = model_outputs.loss
+            loss = model_outputs.loss.item()
 
         ibest_writer = None
         if kwargs.get("output_dir") is not None:
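The third hunk is the bug fix named in the commit message and changes two things: the missing slice colon made preds the single token column at index source_ids.shape[1] instead of every token from that position onward, and loss is now converted to a plain Python float with .item(). A minimal sketch of the difference, with made-up shapes (batch, src_len, and so on are illustrative, not taken from FunASR):

import torch

# Stand-in for model_outputs.logits: (batch, sequence, vocab).
batch, src_len, total_len, vocab = 2, 4, 10, 32
logits = torch.randn(batch, total_len, vocab)

token_ids = torch.argmax(logits, -1)  # (batch, total_len)

wrong = token_ids[:, src_len]    # indexing: shape (2,), one token per batch item
fixed = token_ids[:, src_len:]   # slicing: shape (2, 6), all generated tokens

print(wrong.shape)  # torch.Size([2])
print(fixed.shape)  # torch.Size([2, 6])

# .item() yields a detached Python float from the 0-dim loss tensor,
# which is what you want for logging rather than keeping the graph alive.
loss = torch.tensor(1.2345, requires_grad=True)
print(loss.item())  # 1.2345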