mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
wav fronend
This commit is contained in:
parent
bab0675c36
commit
d3d2fe73c0
@ -75,6 +75,7 @@ def apply_lfr(inputs, lfr_m, lfr_n):
|
||||
LFR_outputs = torch.vstack(LFR_inputs)
|
||||
return LFR_outputs.type(torch.float32)
|
||||
|
||||
@tables.register("frontend_classes", "wav_frontend")
|
||||
@tables.register("frontend_classes", "WavFrontend")
|
||||
class WavFrontend(nn.Module):
|
||||
"""Conventional frontend structure for ASR.
|
||||
|
||||
@ -146,7 +146,7 @@ class Trainer:
|
||||
"""
|
||||
ckpt = os.path.join(resume_path, "model.pt")
|
||||
if os.path.isfile(ckpt):
|
||||
checkpoint = torch.load(ckpt)
|
||||
checkpoint = torch.load(ckpt, map_location="cpu")
|
||||
self.start_epoch = checkpoint['epoch'] + 1
|
||||
# self.model.load_state_dict(checkpoint['state_dict'])
|
||||
src_state = checkpoint['state_dict']
|
||||
@ -169,7 +169,8 @@ class Trainer:
|
||||
print(f"Checkpoint loaded successfully from '{ckpt}'")
|
||||
else:
|
||||
print(f"No checkpoint found at '{ckpt}', does not resume status!")
|
||||
|
||||
|
||||
self.model.to(self.device)
|
||||
if self.use_ddp or self.use_fsdp:
|
||||
dist.barrier()
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user