This commit is contained in:
游雁 2024-06-12 17:17:03 +08:00
parent 765e6371bb
commit 9afcf0ea7d

View File

@ -410,15 +410,17 @@ class LLMASR2(nn.Module):
freeze_layer_num = int(audio_encoder_conf.get("freeze_layer_num", -1))
if freeze_layer_num > 0:
freeze_layer_num = range(freeze_layer_num)
else:
freeze_layer_num = [freeze_layer_num]
if freeze:
for name, param in audio_encoder.named_parameters():
idx = re.search(r"\.\d+\.", name)
if idx is not None:
beg, end = idx.regs[0]
layer_id = int(name[beg + 1 : end - 1])
if layer_id in freeze_layer_num:
if isinstance(freeze_layer_num, (list, tuple)):
if layer_id in freeze_layer_num:
param.requires_grad = False
else:
param.requires_grad = False
audio_encoder.eval()