From 2191795f742063b1c0a394fc2a65898445ccce65 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Sat, 8 Jun 2024 19:45:15 +0800 Subject: [PATCH] fix bug --- funasr/models/llm_asr/model.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py index 697f78dc7..f8c3efc77 100644 --- a/funasr/models/llm_asr/model.py +++ b/funasr/models/llm_asr/model.py @@ -692,6 +692,7 @@ class LLMASR2(nn.Module): batch_idx, :min_len, : ] + label = contents["assistant"][0] if not kwargs.get("tearchforing", False): generated_ids = self.llm.generate( @@ -704,7 +705,7 @@ class LLMASR2(nn.Module): response = tokenizer.batch_decode( generated_ids, skip_special_tokens=kwargs.get("skip_special_tokens", True) )[0] - label = contents["assistant"][0] + loss = None else: @@ -715,13 +716,13 @@ class LLMASR2(nn.Module): inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids ) - preds = torch.argmax(model_outputs.logits, -1)[:, source_ids.shape[1]] + preds = torch.argmax(model_outputs.logits, -1)[:, source_ids.shape[1] :] response = tokenizer.batch_decode( preds, add_special_tokens=False, skip_special_tokens=kwargs.get("skip_special_tokens", True), )[0] - loss = model_outputs.loss + loss = model_outputs.loss.item() ibest_writer = None if kwargs.get("output_dir") is not None: