This commit is contained in:
游雁 2024-06-12 14:29:28 +08:00
parent be26169447
commit 7d57828086

View File

@ -407,9 +407,19 @@ class LLMASR2(nn.Module):
audio_encoder = encoder_class(input_size=input_size, **audio_encoder_conf)
audio_encoder_output_size = audio_encoder.output_size()
freeze = audio_encoder_conf.get("freeze", True)
freeze_layer_num = audio_encoder_conf.get("freeze_layer_num", -1)
if freeze_layer_num > 0:
freeze_layer_num = range(freeze_layer_num)
else:
freeze_layer_num = [freeze_layer_num]
if freeze:
for name, param in audio_encoder.named_parameters():
param.requires_grad = False
idx = re.search(r"\.\d+\.", name)
if idx is not None:
beg, end = idx.regs[0]
layer_id = int(name[beg + 1 : end - 1])
if layer_id in freeze_layer_num:
param.requires_grad = False
audio_encoder.eval()
self.audio_encoder = audio_encoder