|
@@ -85,12 +85,12 @@ class DiarSondModel(AbsESPnetModel):
|
|
|
normalize_length=length_normalized_loss,
|
|
normalize_length=length_normalized_loss,
|
|
|
)
|
|
)
|
|
|
self.criterion_bce = SequenceBinaryCrossEntropy(normalize_length=length_normalized_loss)
|
|
self.criterion_bce = SequenceBinaryCrossEntropy(normalize_length=length_normalized_loss)
|
|
|
- pse_embedding = self.generate_pse_embedding()
|
|
|
|
|
- self.register_buffer("pse_embedding", pse_embedding)
|
|
|
|
|
- power_weight = torch.from_numpy(2 ** np.arange(max_spk_num)[np.newaxis, np.newaxis, :]).float()
|
|
|
|
|
- self.register_buffer("power_weight", power_weight)
|
|
|
|
|
- int_token_arr = torch.from_numpy(np.array(self.token_list).astype(int)[np.newaxis, np.newaxis, :]).int()
|
|
|
|
|
- self.register_buffer("int_token_arr", int_token_arr)
|
|
|
|
|
|
|
+ self.pse_embedding = self.generate_pse_embedding()
|
|
|
|
|
+ # self.register_buffer("pse_embedding", pse_embedding)
|
|
|
|
|
+ self.power_weight = torch.from_numpy(2 ** np.arange(max_spk_num)[np.newaxis, np.newaxis, :]).float()
|
|
|
|
|
+ # self.register_buffer("power_weight", power_weight)
|
|
|
|
|
+ self.int_token_arr = torch.from_numpy(np.array(self.token_list).astype(int)[np.newaxis, np.newaxis, :]).int()
|
|
|
|
|
+ # self.register_buffer("int_token_arr", int_token_arr)
|
|
|
self.speaker_discrimination_loss_weight = speaker_discrimination_loss_weight
|
|
self.speaker_discrimination_loss_weight = speaker_discrimination_loss_weight
|
|
|
self.inter_score_loss_weight = inter_score_loss_weight
|
|
self.inter_score_loss_weight = inter_score_loss_weight
|
|
|
self.forward_steps = 0
|
|
self.forward_steps = 0
|