diff --git a/funasr/models/llm_asr_nar/model.py b/funasr/models/llm_asr_nar/model.py
index 94dae4d90..6a4eccefe 100644
--- a/funasr/models/llm_asr_nar/model.py
+++ b/funasr/models/llm_asr_nar/model.py
@@ -315,7 +315,7 @@ class LLMASRNAR(nn.Module):
         model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=None)
         preds = torch.argmax(model_outputs.logits, -1)
         text = tokenizer.batch_decode(preds, add_special_tokens=False, skip_special_tokens=True)
-        text = text.split(': \n')[-1]
+        text = text[0].split(': \n')[-1]
 
         # preds = torch.argmax(model_outputs.logits, -1)
         ibest_writer = None
diff --git a/funasr/train_utils/load_pretrained_model.py b/funasr/train_utils/load_pretrained_model.py
index 7a8c37459..84c632027 100644
--- a/funasr/train_utils/load_pretrained_model.py
+++ b/funasr/train_utils/load_pretrained_model.py
@@ -90,10 +90,12 @@ def load_pretrained_model(
 
         if dst_prefix == "" and (src_prefix + k) in src_state.keys():
             k_src = src_prefix + k
-            print(f"init param, map: {k} from {k_src} in ckpt")
+            if not k_src.startswith("module."):
+                print(f"init param, map: {k} from {k_src} in ckpt")
         elif k.startswith(dst_prefix) and k.replace(dst_prefix, src_prefix, 1) in src_state.keys():
             k_src = k.replace(dst_prefix, src_prefix, 1)
-            print(f"init param, map: {k} from {k_src} in ckpt")
+            if not k_src.startswith("module."):
+                print(f"init param, map: {k} from {k_src} in ckpt")
 
         if k_src in src_state.keys():
             if ignore_init_mismatch and dst_state[k].shape != src_state[k_src].shape: