|
|
@@ -235,8 +235,11 @@ class Speech2DiarizationSOND:
|
|
|
new_seq.append(x)
|
|
|
else:
|
|
|
idx_list = np.where(seq < 2 ** vec_dim)[0]
|
|
|
- idx = np.abs(idx_list - i).argmin()
|
|
|
- new_seq.append(seq[idx_list[idx]])
|
|
|
+ if len(idx_list) > 0:
|
|
|
+ idx = np.abs(idx_list - i).argmin()
|
|
|
+ new_seq.append(seq[idx_list[idx]])
|
|
|
+ else:
|
|
|
+ new_seq.append(0)
|
|
|
return np.row_stack([int2vec(x, vec_dim) for x in new_seq])
|
|
|
|
|
|
def post_processing(self, raw_logits: torch.Tensor, spk_num: int, output_format: str = "speaker_turn"):
|