|
|
@@ -52,15 +52,15 @@ class DiarEENDOLAModel(AbsESPnetModel):
|
|
|
|
|
|
super().__init__()
|
|
|
self.frontend = frontend
|
|
|
- self.encoder = encoder
|
|
|
- self.encoder_decoder_attractor = encoder_decoder_attractor
|
|
|
+ self.enc = encoder
|
|
|
+ self.eda = encoder_decoder_attractor
|
|
|
self.attractor_loss_weight = attractor_loss_weight
|
|
|
self.max_n_speaker = max_n_speaker
|
|
|
if mapping_dict is None:
|
|
|
mapping_dict = generate_mapping_dict(max_speaker_num=self.max_n_speaker)
|
|
|
self.mapping_dict = mapping_dict
|
|
|
# PostNet
|
|
|
- self.PostNet = nn.LSTM(self.max_n_speaker, n_units, 1, batch_first=True)
|
|
|
+ self.postnet = nn.LSTM(self.max_n_speaker, n_units, 1, batch_first=True)
|
|
|
self.output_layer = nn.Linear(n_units, mapping_dict['oov'] + 1)
|
|
|
|
|
|
def forward_encoder(self, xs, ilens):
|
|
|
@@ -68,7 +68,7 @@ class DiarEENDOLAModel(AbsESPnetModel):
|
|
|
pad_shape = xs.shape
|
|
|
xs_mask = [torch.ones(ilen).to(xs.device) for ilen in ilens]
|
|
|
xs_mask = torch.nn.utils.rnn.pad_sequence(xs_mask, batch_first=True, padding_value=0).unsqueeze(-2)
|
|
|
- emb = self.encoder(xs, xs_mask)
|
|
|
+ emb = self.enc(xs, xs_mask)
|
|
|
emb = torch.split(emb.view(pad_shape[0], pad_shape[1], -1), 1, dim=0)
|
|
|
emb = [e[0][:ilen] for e, ilen in zip(emb, ilens)]
|
|
|
return emb
|
|
|
@@ -77,7 +77,7 @@ class DiarEENDOLAModel(AbsESPnetModel):
|
|
|
maxlen = torch.max(ilens).to(torch.int).item()
|
|
|
logits = nn.utils.rnn.pad_sequence(logits, batch_first=True, padding_value=-1)
|
|
|
logits = nn.utils.rnn.pack_padded_sequence(logits, ilens, batch_first=True, enforce_sorted=False)
|
|
|
- outputs, (_, _) = self.PostNet(logits)
|
|
|
+ outputs, (_, _) = self.postnet(logits)
|
|
|
outputs = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True, padding_value=-1, total_length=maxlen)[0]
|
|
|
outputs = [output[:ilens[i].to(torch.int).item()] for i, output in enumerate(outputs)]
|
|
|
outputs = [self.output_layer(output) for output in outputs]
|
|
|
@@ -112,7 +112,7 @@ class DiarEENDOLAModel(AbsESPnetModel):
|
|
|
text = text[:, : text_lengths.max()]
|
|
|
|
|
|
# 1. Encoder
|
|
|
- encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
|
|
|
+ encoder_out, encoder_out_lens = self.enc(speech, speech_lengths)
|
|
|
intermediate_outs = None
|
|
|
if isinstance(encoder_out, tuple):
|
|
|
intermediate_outs = encoder_out[1]
|
|
|
@@ -198,10 +198,10 @@ class DiarEENDOLAModel(AbsESPnetModel):
|
|
|
orders = [np.arange(e.shape[0]) for e in emb]
|
|
|
for order in orders:
|
|
|
np.random.shuffle(order)
|
|
|
- attractors, probs = self.encoder_decoder_attractor.estimate(
|
|
|
+ attractors, probs = self.eda.estimate(
|
|
|
[e[torch.from_numpy(order).to(torch.long).to(speech[0].device)] for e, order in zip(emb, orders)])
|
|
|
else:
|
|
|
- attractors, probs = self.encoder_decoder_attractor.estimate(emb)
|
|
|
+ attractors, probs = self.eda.estimate(emb)
|
|
|
attractors_active = []
|
|
|
for p, att, e in zip(probs, attractors, emb):
|
|
|
if n_speakers and n_speakers >= 0:
|