mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
fixbug sond initial
This commit is contained in:
parent
88efde8799
commit
97f8201138
@ -85,12 +85,12 @@ class DiarSondModel(AbsESPnetModel):
|
||||
normalize_length=length_normalized_loss,
|
||||
)
|
||||
self.criterion_bce = SequenceBinaryCrossEntropy(normalize_length=length_normalized_loss)
|
||||
pse_embedding = self.generate_pse_embedding()
|
||||
self.register_buffer("pse_embedding", pse_embedding)
|
||||
power_weight = torch.from_numpy(2 ** np.arange(max_spk_num)[np.newaxis, np.newaxis, :]).float()
|
||||
self.register_buffer("power_weight", power_weight)
|
||||
int_token_arr = torch.from_numpy(np.array(self.token_list).astype(int)[np.newaxis, np.newaxis, :]).int()
|
||||
self.register_buffer("int_token_arr", int_token_arr)
|
||||
self.pse_embedding = self.generate_pse_embedding()
|
||||
# self.register_buffer("pse_embedding", pse_embedding)
|
||||
self.power_weight = torch.from_numpy(2 ** np.arange(max_spk_num)[np.newaxis, np.newaxis, :]).float()
|
||||
# self.register_buffer("power_weight", power_weight)
|
||||
self.int_token_arr = torch.from_numpy(np.array(self.token_list).astype(int)[np.newaxis, np.newaxis, :]).int()
|
||||
# self.register_buffer("int_token_arr", int_token_arr)
|
||||
self.speaker_discrimination_loss_weight = speaker_discrimination_loss_weight
|
||||
self.inter_score_loss_weight = inter_score_loss_weight
|
||||
self.forward_steps = 0
|
||||
|
||||
Loading…
Reference in New Issue
Block a user