diff --git a/funasr/build_utils/build_model_from_file.py b/funasr/build_utils/build_model_from_file.py index 85cf8b9d6..8fd4e4638 100644 --- a/funasr/build_utils/build_model_from_file.py +++ b/funasr/build_utils/build_model_from_file.py @@ -74,7 +74,10 @@ def build_model_from_file( model_dict = torch.load(model_file, map_location=device) if task_name == "diar" and mode == "sond": model_dict = fileter_model_dict(model_dict, model.state_dict()) - model.load_state_dict(model_dict) + if task_name == "vad": + model.encoder.load_state_dict(model_dict) + else: + model.load_state_dict(model_dict) if model_name_pth is not None and not os.path.exists(model_name_pth): torch.save(model_dict, model_name_pth) logging.info("model_file is saved to pth: {}".format(model_name_pth))