mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
modify unit test for speech_diarization_sond-en-us-callhome-8k-n16k4-pytorch
This commit is contained in:
parent
c3ca7d963e
commit
7907c3df07
@ -573,19 +573,24 @@ class DiarTask(AbsTask):
|
||||
var_dict_torch = model.state_dict()
|
||||
var_dict_torch_update = dict()
|
||||
# speech encoder
|
||||
var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
|
||||
var_dict_torch_update.update(var_dict_torch_update_local)
|
||||
if model.encoder is not None:
|
||||
var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
|
||||
var_dict_torch_update.update(var_dict_torch_update_local)
|
||||
# speaker encoder
|
||||
var_dict_torch_update_local = model.speaker_encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
|
||||
var_dict_torch_update.update(var_dict_torch_update_local)
|
||||
if model.speaker_encoder is not None:
|
||||
var_dict_torch_update_local = model.speaker_encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
|
||||
var_dict_torch_update.update(var_dict_torch_update_local)
|
||||
# cd scorer
|
||||
var_dict_torch_update_local = model.cd_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
|
||||
var_dict_torch_update.update(var_dict_torch_update_local)
|
||||
if model.cd_scorer is not None:
|
||||
var_dict_torch_update_local = model.cd_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
|
||||
var_dict_torch_update.update(var_dict_torch_update_local)
|
||||
# ci scorer
|
||||
var_dict_torch_update_local = model.ci_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
|
||||
var_dict_torch_update.update(var_dict_torch_update_local)
|
||||
if model.ci_scorer is not None:
|
||||
var_dict_torch_update_local = model.ci_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
|
||||
var_dict_torch_update.update(var_dict_torch_update_local)
|
||||
# decoder
|
||||
var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
|
||||
var_dict_torch_update.update(var_dict_torch_update_local)
|
||||
if model.decoder is not None:
|
||||
var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
|
||||
var_dict_torch_update.update(var_dict_torch_update_local)
|
||||
|
||||
return var_dict_torch_update
|
||||
|
||||
Loading…
Reference in New Issue
Block a user