From e4a69d4768674e57faf4a08eecca2fce88d3e190 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Wed, 12 Jun 2024 19:17:55 +0800 Subject: [PATCH] decoding --- funasr/models/llm_asr/model.py | 15 ++++++++------- funasr/train_utils/trainer_ds.py | 1 + 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py index fb0bee3ca..14837b936 100644 --- a/funasr/models/llm_asr/model.py +++ b/funasr/models/llm_asr/model.py @@ -413,15 +413,16 @@ class LLMASR2(nn.Module): if freeze: for name, param in audio_encoder.named_parameters(): - idx = re.search(r"\.\d+\.", name) - if idx is not None: - beg, end = idx.regs[0] - layer_id = int(name[beg + 1 : end - 1]) - if isinstance(freeze_layer_num, (list, tuple)): + if isinstance(freeze_layer_num, (list, tuple)): + idx = re.search(r"\.\d+\.", name) + if idx is not None: + beg, end = idx.regs[0] + layer_id = int(name[beg + 1 : end - 1]) if layer_id in freeze_layer_num: param.requires_grad = False - else: - param.requires_grad = False + else: + param.requires_grad = False + audio_encoder.eval() self.audio_encoder = audio_encoder diff --git a/funasr/train_utils/trainer_ds.py b/funasr/train_utils/trainer_ds.py index b2d7b95d8..22be426e9 100644 --- a/funasr/train_utils/trainer_ds.py +++ b/funasr/train_utils/trainer_ds.py @@ -313,6 +313,7 @@ class Trainer: state_dict = model.state_dict() if self.effective_save_name_excludes is not None: + logging.info(f"effective_save_name_excludes: {self.effective_save_name_excludes}") dst_state_dict = {} for k in state_dict.keys(): for k_ex in self.effective_save_name_excludes: