diff --git a/funasr/frontends/wav_frontend.py b/funasr/frontends/wav_frontend.py index c6e03e86e..afa7421ca 100644 --- a/funasr/frontends/wav_frontend.py +++ b/funasr/frontends/wav_frontend.py @@ -75,6 +75,7 @@ def apply_lfr(inputs, lfr_m, lfr_n): LFR_outputs = torch.vstack(LFR_inputs) return LFR_outputs.type(torch.float32) +@tables.register("frontend_classes", "wav_frontend") @tables.register("frontend_classes", "WavFrontend") class WavFrontend(nn.Module): """Conventional frontend structure for ASR. diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py index a00b3de0f..14abd6cc8 100644 --- a/funasr/train_utils/trainer.py +++ b/funasr/train_utils/trainer.py @@ -146,7 +146,7 @@ class Trainer: """ ckpt = os.path.join(resume_path, "model.pt") if os.path.isfile(ckpt): - checkpoint = torch.load(ckpt) + checkpoint = torch.load(ckpt, map_location="cpu") self.start_epoch = checkpoint['epoch'] + 1 # self.model.load_state_dict(checkpoint['state_dict']) src_state = checkpoint['state_dict'] @@ -169,7 +169,8 @@ class Trainer: print(f"Checkpoint loaded successfully from '{ckpt}'") else: print(f"No checkpoint found at '{ckpt}', does not resume status!") - + + self.model.to(self.device) if self.use_ddp or self.use_fsdp: dist.barrier()