志浩 3 лет назад
Родитель
Сommit
97f8201138
1 измененных файлов с 6 добавлено и 6 удалено
  1. 6 6
      funasr/models/e2e_diar_sond.py

+ 6 - 6
funasr/models/e2e_diar_sond.py

@@ -85,12 +85,12 @@ class DiarSondModel(AbsESPnetModel):
             normalize_length=length_normalized_loss,
         )
         self.criterion_bce = SequenceBinaryCrossEntropy(normalize_length=length_normalized_loss)
-        pse_embedding = self.generate_pse_embedding()
-        self.register_buffer("pse_embedding", pse_embedding)
-        power_weight = torch.from_numpy(2 ** np.arange(max_spk_num)[np.newaxis, np.newaxis, :]).float()
-        self.register_buffer("power_weight", power_weight)
-        int_token_arr = torch.from_numpy(np.array(self.token_list).astype(int)[np.newaxis, np.newaxis, :]).int()
-        self.register_buffer("int_token_arr", int_token_arr)
+        self.pse_embedding = self.generate_pse_embedding()
+        # self.register_buffer("pse_embedding", pse_embedding)
+        self.power_weight = torch.from_numpy(2 ** np.arange(max_spk_num)[np.newaxis, np.newaxis, :]).float()
+        # self.register_buffer("power_weight", power_weight)
+        self.int_token_arr = torch.from_numpy(np.array(self.token_list).astype(int)[np.newaxis, np.newaxis, :]).int()
+        # self.register_buffer("int_token_arr", int_token_arr)
         self.speaker_discrimination_loss_weight = speaker_discrimination_loss_weight
         self.inter_score_loss_weight = inter_score_loss_weight
         self.forward_steps = 0