mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
fix bug
This commit is contained in:
parent
3d5e19792c
commit
2191795f74
@ -692,6 +692,7 @@ class LLMASR2(nn.Module):
|
||||
batch_idx, :min_len, :
|
||||
]
|
||||
|
||||
label = contents["assistant"][0]
|
||||
if not kwargs.get("tearchforing", False):
|
||||
|
||||
generated_ids = self.llm.generate(
|
||||
@ -704,7 +705,7 @@ class LLMASR2(nn.Module):
|
||||
response = tokenizer.batch_decode(
|
||||
generated_ids, skip_special_tokens=kwargs.get("skip_special_tokens", True)
|
||||
)[0]
|
||||
label = contents["assistant"][0]
|
||||
|
||||
loss = None
|
||||
else:
|
||||
|
||||
@ -715,13 +716,13 @@ class LLMASR2(nn.Module):
|
||||
inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids
|
||||
)
|
||||
|
||||
preds = torch.argmax(model_outputs.logits, -1)[:, source_ids.shape[1]]
|
||||
preds = torch.argmax(model_outputs.logits, -1)[:, source_ids.shape[1] :]
|
||||
response = tokenizer.batch_decode(
|
||||
preds,
|
||||
add_special_tokens=False,
|
||||
skip_special_tokens=kwargs.get("skip_special_tokens", True),
|
||||
)[0]
|
||||
loss = model_outputs.loss
|
||||
loss = model_outputs.loss.item()
|
||||
|
||||
ibest_writer = None
|
||||
if kwargs.get("output_dir") is not None:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user