mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
decoding
This commit is contained in:
parent
5de8bfdcd8
commit
6ca0b838d4
@ -410,17 +410,17 @@ class LLMASR2(nn.Module):
|
||||
audio_encoder_output_size = audio_encoder.output_size()
|
||||
freeze = audio_encoder_conf.get("freeze", True)
|
||||
freeze_layer_num = int(audio_encoder_conf.get("freeze_layer_num", -1))
|
||||
if freeze_layer_num > 0:
|
||||
freeze_layer_num = range(freeze_layer_num)
|
||||
# if freeze_layer_num > 0:
|
||||
# freeze_layer_num = range(freeze_layer_num)
|
||||
|
||||
if freeze:
|
||||
for name, param in audio_encoder.named_parameters():
|
||||
if isinstance(freeze_layer_num, (list, tuple)):
|
||||
if freeze_layer_num > 0:
|
||||
idx = re.search(r"\.\d+\.", name)
|
||||
if idx is not None:
|
||||
beg, end = idx.regs[0]
|
||||
layer_id = int(name[beg + 1 : end - 1])
|
||||
if layer_id in freeze_layer_num:
|
||||
if layer_id < freeze_layer_num:
|
||||
param.requires_grad = False
|
||||
else:
|
||||
param.requires_grad = False
|
||||
|
||||
Loading…
Reference in New Issue
Block a user