diff --git a/funasr/tasks/diar.py b/funasr/tasks/diar.py index 45e3d7a8a..e699dccb0 100644 --- a/funasr/tasks/diar.py +++ b/funasr/tasks/diar.py @@ -573,19 +573,24 @@ class DiarTask(AbsTask): var_dict_torch = model.state_dict() var_dict_torch_update = dict() # speech encoder - var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch) - var_dict_torch_update.update(var_dict_torch_update_local) + if model.encoder is not None: + var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch) + var_dict_torch_update.update(var_dict_torch_update_local) # speaker encoder - var_dict_torch_update_local = model.speaker_encoder.convert_tf2torch(var_dict_tf, var_dict_torch) - var_dict_torch_update.update(var_dict_torch_update_local) + if model.speaker_encoder is not None: + var_dict_torch_update_local = model.speaker_encoder.convert_tf2torch(var_dict_tf, var_dict_torch) + var_dict_torch_update.update(var_dict_torch_update_local) # cd scorer - var_dict_torch_update_local = model.cd_scorer.convert_tf2torch(var_dict_tf, var_dict_torch) - var_dict_torch_update.update(var_dict_torch_update_local) + if model.cd_scorer is not None: + var_dict_torch_update_local = model.cd_scorer.convert_tf2torch(var_dict_tf, var_dict_torch) + var_dict_torch_update.update(var_dict_torch_update_local) # ci scorer - var_dict_torch_update_local = model.ci_scorer.convert_tf2torch(var_dict_tf, var_dict_torch) - var_dict_torch_update.update(var_dict_torch_update_local) + if model.ci_scorer is not None: + var_dict_torch_update_local = model.ci_scorer.convert_tf2torch(var_dict_tf, var_dict_torch) + var_dict_torch_update.update(var_dict_torch_update_local) # decoder - var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch) - var_dict_torch_update.update(var_dict_torch_update_local) + if model.decoder is not None: + var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch) + var_dict_torch_update.update(var_dict_torch_update_local) return var_dict_torch_update