fixbug sond initial

This commit is contained in:
志浩 2023-02-27 15:03:07 +08:00
parent 88efde8799
commit 97f8201138

View File

@ -85,12 +85,12 @@ class DiarSondModel(AbsESPnetModel):
normalize_length=length_normalized_loss,
)
self.criterion_bce = SequenceBinaryCrossEntropy(normalize_length=length_normalized_loss)
pse_embedding = self.generate_pse_embedding()
self.register_buffer("pse_embedding", pse_embedding)
power_weight = torch.from_numpy(2 ** np.arange(max_spk_num)[np.newaxis, np.newaxis, :]).float()
self.register_buffer("power_weight", power_weight)
int_token_arr = torch.from_numpy(np.array(self.token_list).astype(int)[np.newaxis, np.newaxis, :]).int()
self.register_buffer("int_token_arr", int_token_arr)
self.pse_embedding = self.generate_pse_embedding()
# self.register_buffer("pse_embedding", pse_embedding)
self.power_weight = torch.from_numpy(2 ** np.arange(max_spk_num)[np.newaxis, np.newaxis, :]).float()
# self.register_buffer("power_weight", power_weight)
self.int_token_arr = torch.from_numpy(np.array(self.token_list).astype(int)[np.newaxis, np.newaxis, :]).int()
# self.register_buffer("int_token_arr", int_token_arr)
self.speaker_discrimination_loss_weight = speaker_discrimination_loss_weight
self.inter_score_loss_weight = inter_score_loss_weight
self.forward_steps = 0