wav fronend

This commit is contained in:
游雁 2024-03-18 20:46:23 +08:00
parent bab0675c36
commit d3d2fe73c0
2 changed files with 4 additions and 2 deletions

View File

@ -75,6 +75,7 @@ def apply_lfr(inputs, lfr_m, lfr_n):
LFR_outputs = torch.vstack(LFR_inputs) LFR_outputs = torch.vstack(LFR_inputs)
return LFR_outputs.type(torch.float32) return LFR_outputs.type(torch.float32)
@tables.register("frontend_classes", "wav_frontend")
@tables.register("frontend_classes", "WavFrontend") @tables.register("frontend_classes", "WavFrontend")
class WavFrontend(nn.Module): class WavFrontend(nn.Module):
"""Conventional frontend structure for ASR. """Conventional frontend structure for ASR.

View File

@ -146,7 +146,7 @@ class Trainer:
""" """
ckpt = os.path.join(resume_path, "model.pt") ckpt = os.path.join(resume_path, "model.pt")
if os.path.isfile(ckpt): if os.path.isfile(ckpt):
checkpoint = torch.load(ckpt) checkpoint = torch.load(ckpt, map_location="cpu")
self.start_epoch = checkpoint['epoch'] + 1 self.start_epoch = checkpoint['epoch'] + 1
# self.model.load_state_dict(checkpoint['state_dict']) # self.model.load_state_dict(checkpoint['state_dict'])
src_state = checkpoint['state_dict'] src_state = checkpoint['state_dict']
@ -169,7 +169,8 @@ class Trainer:
print(f"Checkpoint loaded successfully from '{ckpt}'") print(f"Checkpoint loaded successfully from '{ckpt}'")
else: else:
print(f"No checkpoint found at '{ckpt}', does not resume status!") print(f"No checkpoint found at '{ckpt}', does not resume status!")
self.model.to(self.device)
if self.use_ddp or self.use_fsdp: if self.use_ddp or self.use_fsdp:
dist.barrier() dist.barrier()