From 6ca0b838d48106030984eacf204e8f1f2f05985b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Thu, 13 Jun 2024 16:07:49 +0800 Subject: [PATCH] decoding --- funasr/models/llm_asr/model.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py index 15969e35b..85351b722 100644 --- a/funasr/models/llm_asr/model.py +++ b/funasr/models/llm_asr/model.py @@ -410,17 +410,17 @@ class LLMASR2(nn.Module): audio_encoder_output_size = audio_encoder.output_size() freeze = audio_encoder_conf.get("freeze", True) freeze_layer_num = int(audio_encoder_conf.get("freeze_layer_num", -1)) - if freeze_layer_num > 0: - freeze_layer_num = range(freeze_layer_num) + # if freeze_layer_num > 0: + # freeze_layer_num = range(freeze_layer_num) if freeze: for name, param in audio_encoder.named_parameters(): - if isinstance(freeze_layer_num, (list, tuple)): + if freeze_layer_num > 0: idx = re.search(r"\.\d+\.", name) if idx is not None: beg, end = idx.regs[0] layer_id = int(name[beg + 1 : end - 1]) - if layer_id in freeze_layer_num: + if layer_id < freeze_layer_num: param.requires_grad = False else: param.requires_grad = False