From d0f353205297be80ecd0b84eb6280c7865f4bcf5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?=
Date: Wed, 28 Feb 2024 20:40:35 +0800
Subject: [PATCH] init param

---
 funasr/models/llm_asr_nar/model.py          | 2 +-
 funasr/train_utils/load_pretrained_model.py | 6 ++++--
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/funasr/models/llm_asr_nar/model.py b/funasr/models/llm_asr_nar/model.py
index 94dae4d90..6a4eccefe 100644
--- a/funasr/models/llm_asr_nar/model.py
+++ b/funasr/models/llm_asr_nar/model.py
@@ -315,7 +315,7 @@ class LLMASRNAR(nn.Module):
         model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=None)
         preds = torch.argmax(model_outputs.logits, -1)
         text = tokenizer.batch_decode(preds, add_special_tokens=False, skip_special_tokens=True)
-        text = text.split(': \n')[-1]
+        text = text[0].split(': \n')[-1]
 
         # preds = torch.argmax(model_outputs.logits, -1)
         ibest_writer = None
diff --git a/funasr/train_utils/load_pretrained_model.py b/funasr/train_utils/load_pretrained_model.py
index 7a8c37459..84c632027 100644
--- a/funasr/train_utils/load_pretrained_model.py
+++ b/funasr/train_utils/load_pretrained_model.py
@@ -90,10 +90,12 @@ def load_pretrained_model(
 
         if dst_prefix == "" and (src_prefix + k) in src_state.keys():
             k_src = src_prefix + k
-            print(f"init param, map: {k} from {k_src} in ckpt")
+            if not k_src.startswith("module."):
+                print(f"init param, map: {k} from {k_src} in ckpt")
         elif k.startswith(dst_prefix) and k.replace(dst_prefix, src_prefix, 1) in src_state.keys():
             k_src = k.replace(dst_prefix, src_prefix, 1)
-            print(f"init param, map: {k} from {k_src} in ckpt")
+            if not k_src.startswith("module."):
+                print(f"init param, map: {k} from {k_src} in ckpt")
 
         if k_src in src_state.keys():
            if ignore_init_mismatch and dst_state[k].shape != src_state[k_src].shape:
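
Notes (appended after the patch; not part of the commit message):

Why the model.py change is needed: Hugging Face's tokenizer.batch_decode
returns a list of decoded strings, one per batch item, so the previous line
called str.split on a list and would raise AttributeError. A minimal sketch,
assuming a Hugging Face tokenizer; the "gpt2" model name and the
"transcript: \n" prefix are illustrative, not from the FunASR codebase:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative model
    ids = tokenizer("transcript: \nhello world", return_tensors="pt").input_ids
    text = tokenizer.batch_decode(ids, skip_special_tokens=True)
    print(type(text))                 # <class 'list'>, not str
    # Take the first (and here only) hypothesis before splitting:
    print(text[0].split(': \n')[-1])  # -> 'hello world'

Why load_pretrained_model.py filters "module.": state dicts saved from a
DataParallel/DDP-wrapped model carry a "module." prefix on every key, which
is presumably why the patch drops those keys from the per-parameter
"init param" logging to cut log noise. A minimal sketch of where the prefix
comes from:

    import torch.nn as nn

    model = nn.Linear(2, 2)
    wrapped = nn.DataParallel(model)
    print(list(model.state_dict()))    # ['weight', 'bias']
    print(list(wrapped.state_dict()))  # ['module.weight', 'module.bias']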