This commit is contained in:
游雁 2024-06-13 16:07:49 +08:00
parent 5de8bfdcd8
commit 6ca0b838d4

View File

@ -410,17 +410,17 @@ class LLMASR2(nn.Module):
audio_encoder_output_size = audio_encoder.output_size()
freeze = audio_encoder_conf.get("freeze", True)
freeze_layer_num = int(audio_encoder_conf.get("freeze_layer_num", -1))
if freeze_layer_num > 0:
freeze_layer_num = range(freeze_layer_num)
# if freeze_layer_num > 0:
# freeze_layer_num = range(freeze_layer_num)
if freeze:
for name, param in audio_encoder.named_parameters():
if isinstance(freeze_layer_num, (list, tuple)):
if freeze_layer_num > 0:
idx = re.search(r"\.\d+\.", name)
if idx is not None:
beg, end = idx.regs[0]
layer_id = int(name[beg + 1 : end - 1])
if layer_id in freeze_layer_num:
if layer_id < freeze_layer_num:
param.requires_grad = False
else:
param.requires_grad = False