From d3d2fe73c08ee51d3a44d7ffb7b31eff32b60404 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?=
Date: Mon, 18 Mar 2024 20:46:23 +0800
Subject: [PATCH] wav frontend

---
 funasr/frontends/wav_frontend.py | 1 +
 funasr/train_utils/trainer.py    | 5 +++--
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/funasr/frontends/wav_frontend.py b/funasr/frontends/wav_frontend.py
index c6e03e86e..afa7421ca 100644
--- a/funasr/frontends/wav_frontend.py
+++ b/funasr/frontends/wav_frontend.py
@@ -75,6 +75,7 @@ def apply_lfr(inputs, lfr_m, lfr_n):
     LFR_outputs = torch.vstack(LFR_inputs)
     return LFR_outputs.type(torch.float32)
 
+@tables.register("frontend_classes", "wav_frontend")
 @tables.register("frontend_classes", "WavFrontend")
 class WavFrontend(nn.Module):
     """Conventional frontend structure for ASR.
diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index a00b3de0f..14abd6cc8 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -146,7 +146,7 @@ class Trainer:
         """
         ckpt = os.path.join(resume_path, "model.pt")
         if os.path.isfile(ckpt):
-            checkpoint = torch.load(ckpt)
+            checkpoint = torch.load(ckpt, map_location="cpu")
             self.start_epoch = checkpoint['epoch'] + 1
             # self.model.load_state_dict(checkpoint['state_dict'])
             src_state = checkpoint['state_dict']
@@ -169,7 +169,8 @@ class Trainer:
             print(f"Checkpoint loaded successfully from '{ckpt}'")
         else:
             print(f"No checkpoint found at '{ckpt}', does not resume status!")
-
+
+        self.model.to(self.device)
         if self.use_ddp or self.use_fsdp:
             dist.barrier()
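
Note on the wav_frontend.py hunk: stacking a second @tables.register call
exposes the same class under a lowercase "wav_frontend" alias, so configs
that spell the frontend either way resolve to the same class. A minimal
sketch of the mechanism, using a stand-in registry rather than FunASR's
real tables module (assumption: register() records the class and returns
it unchanged, which is what lets the decorators stack):

    # Stand-in registry dict; the real implementation keeps one per group.
    frontend_classes = {}

    def register(group: str, name: str):
        def decorator(cls):
            frontend_classes[name] = cls
            return cls  # returning cls unchanged lets decorators stack
        return decorator

    @register("frontend_classes", "wav_frontend")
    @register("frontend_classes", "WavFrontend")
    class WavFrontend:
        pass

    # Both keys resolve to the same class object.
    assert frontend_classes["wav_frontend"] is frontend_classes["WavFrontend"]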
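
Note on the trainer.py hunks: torch.load(ckpt, map_location="cpu") forces
the checkpoint tensors onto the CPU instead of whatever GPU they were saved
from, avoiding device mismatches and GPU out-of-memory spikes when resuming
on a different device layout; the added self.model.to(self.device) then
moves the model to the target device after the state dict is loaded. A
minimal sketch of that resume pattern (the model, path, and epoch value are
placeholders, not the trainer's real objects):

    import torch
    import torch.nn as nn

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = nn.Linear(4, 2)  # placeholder for the real model

    # Write a checkpoint so the sketch is self-contained.
    torch.save({"state_dict": model.state_dict(), "epoch": 3}, "model.pt")

    # map_location="cpu": tensors land on CPU regardless of the device
    # that saved them.
    checkpoint = torch.load("model.pt", map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"])
    start_epoch = checkpoint["epoch"] + 1

    # Mirrors the added self.model.to(self.device): move the weights to
    # the target device only after the CPU-side load succeeds.
    model.to(device)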