diff --git a/funasr/models/e2e_diar_sond.py b/funasr/models/e2e_diar_sond.py index e68d16b4f..419c8133a 100644 --- a/funasr/models/e2e_diar_sond.py +++ b/funasr/models/e2e_diar_sond.py @@ -85,12 +85,12 @@ class DiarSondModel(AbsESPnetModel): normalize_length=length_normalized_loss, ) self.criterion_bce = SequenceBinaryCrossEntropy(normalize_length=length_normalized_loss) - pse_embedding = self.generate_pse_embedding() - self.register_buffer("pse_embedding", pse_embedding) - power_weight = torch.from_numpy(2 ** np.arange(max_spk_num)[np.newaxis, np.newaxis, :]).float() - self.register_buffer("power_weight", power_weight) - int_token_arr = torch.from_numpy(np.array(self.token_list).astype(int)[np.newaxis, np.newaxis, :]).int() - self.register_buffer("int_token_arr", int_token_arr) + self.pse_embedding = self.generate_pse_embedding() + # self.register_buffer("pse_embedding", pse_embedding) + self.power_weight = torch.from_numpy(2 ** np.arange(max_spk_num)[np.newaxis, np.newaxis, :]).float() + # self.register_buffer("power_weight", power_weight) + self.int_token_arr = torch.from_numpy(np.array(self.token_list).astype(int)[np.newaxis, np.newaxis, :]).int() + # self.register_buffer("int_token_arr", int_token_arr) self.speaker_discrimination_loss_weight = speaker_discrimination_loss_weight self.inter_score_loss_weight = inter_score_loss_weight self.forward_steps = 0