mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
decoding
This commit is contained in:
parent
a56980a26f
commit
407625a734
@ -124,6 +124,7 @@ def main(**kwargs):
|
||||
use_ddp=use_ddp,
|
||||
use_fsdp=use_fsdp,
|
||||
device=kwargs["device"],
|
||||
excludes=kwargs.get("excludes", None),
|
||||
output_dir=kwargs.get("output_dir", "./exp"),
|
||||
**kwargs.get("train_conf"),
|
||||
)
|
||||
|
||||
@ -147,6 +147,10 @@ class Trainer:
|
||||
|
||||
self.use_deepspeed = use_deepspeed
|
||||
self.deepspeed_config = kwargs.get("deepspeed_config", "")
|
||||
self.excludes = kwargs.get("excludes", None)
|
||||
if self.excludes is not None:
|
||||
if isinstance(self.excludes, str):
|
||||
self.excludes = self.excludes.split(",")
|
||||
|
||||
def save_checkpoint(
|
||||
self,
|
||||
@ -440,6 +444,12 @@ class Trainer:
|
||||
src_state = checkpoint["state_dict"]
|
||||
dst_state = model.state_dict()
|
||||
for k in dst_state.keys():
|
||||
if excludes is not None:
|
||||
for k_ex in excludes:
|
||||
k_tmp = k.replace("module.", "")
|
||||
if k_tmp.startswith(k_ex):
|
||||
logging.info(f"key: {{k}} matching: {k_ex}, excluded")
|
||||
continue
|
||||
if not k.startswith("module.") and "module." + k in src_state.keys():
|
||||
k_ddp = "module." + k
|
||||
elif k.startswith("module.") and "module." + k not in src_state.keys():
|
||||
|
||||
Loading…
Reference in New Issue
Block a user