This commit is contained in:
游雁 2024-06-12 19:17:55 +08:00
parent 2518f03d20
commit e4a69d4768
2 changed files with 9 additions and 7 deletions

View File

@ -413,15 +413,16 @@ class LLMASR2(nn.Module):
if freeze: if freeze:
for name, param in audio_encoder.named_parameters(): for name, param in audio_encoder.named_parameters():
idx = re.search(r"\.\d+\.", name) if isinstance(freeze_layer_num, (list, tuple)):
if idx is not None: idx = re.search(r"\.\d+\.", name)
beg, end = idx.regs[0] if idx is not None:
layer_id = int(name[beg + 1 : end - 1]) beg, end = idx.regs[0]
if isinstance(freeze_layer_num, (list, tuple)): layer_id = int(name[beg + 1 : end - 1])
if layer_id in freeze_layer_num: if layer_id in freeze_layer_num:
param.requires_grad = False param.requires_grad = False
else: else:
param.requires_grad = False param.requires_grad = False
audio_encoder.eval() audio_encoder.eval()
self.audio_encoder = audio_encoder self.audio_encoder = audio_encoder

View File

@ -313,6 +313,7 @@ class Trainer:
state_dict = model.state_dict() state_dict = model.state_dict()
if self.effective_save_name_excludes is not None: if self.effective_save_name_excludes is not None:
logging.info(f"effective_save_name_excludes: {self.effective_save_name_excludes}")
dst_state_dict = {} dst_state_dict = {}
for k in state_dict.keys(): for k in state_dict.keys():
for k_ex in self.effective_save_name_excludes: for k_ex in self.effective_save_name_excludes: