wav fronend

This commit is contained in:
游雁 2024-03-18 20:46:23 +08:00
parent bab0675c36
commit d3d2fe73c0
2 changed files with 4 additions and 2 deletions

View File

@ -75,6 +75,7 @@ def apply_lfr(inputs, lfr_m, lfr_n):
LFR_outputs = torch.vstack(LFR_inputs)
return LFR_outputs.type(torch.float32)
@tables.register("frontend_classes", "wav_frontend")
@tables.register("frontend_classes", "WavFrontend")
class WavFrontend(nn.Module):
"""Conventional frontend structure for ASR.

View File

@ -146,7 +146,7 @@ class Trainer:
"""
ckpt = os.path.join(resume_path, "model.pt")
if os.path.isfile(ckpt):
checkpoint = torch.load(ckpt)
checkpoint = torch.load(ckpt, map_location="cpu")
self.start_epoch = checkpoint['epoch'] + 1
# self.model.load_state_dict(checkpoint['state_dict'])
src_state = checkpoint['state_dict']
@ -169,7 +169,8 @@ class Trainer:
print(f"Checkpoint loaded successfully from '{ckpt}'")
else:
print(f"No checkpoint found at '{ckpt}', does not resume status!")
self.model.to(self.device)
if self.use_ddp or self.use_fsdp:
dist.barrier()