@@ -1,21 +1,21 @@
-# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
-# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
-
 from contextlib import contextmanager
 from distutils.version import LooseVersion
-from typing import Dict
-from typing import Tuple
+from typing import Dict, List, Tuple, Optional
 
 import numpy as np
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
+from typeguard import check_argument_types
 
+from funasr.models.base_model import FunASRModel
 from funasr.models.frontend.wav_frontend import WavFrontendMel23
 from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder
 from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
+from funasr.modules.eend_ola.utils.losses import fast_batch_pit_n_speaker_loss, standard_loss, cal_power_loss
+from funasr.modules.eend_ola.utils.power import create_powerlabel
 from funasr.modules.eend_ola.utils.power import generate_mapping_dict
 from funasr.torch_utils.device_funcs import force_gatherable
-from funasr.models.base_model import FunASRModel
 
 if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
     pass
@@ -33,12 +33,35 @@ def pad_attractor(att, max_n_speakers):
     return att
 
 
+def pad_labels(ts, out_size):
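+    # pad the speaker dimension of each (frames x speakers) label tensor with zeros up to out_size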
+    for i, t in enumerate(ts):
+        if t.shape[1] < out_size:
+            ts[i] = F.pad(
+                t,
+                (0, out_size - t.shape[1], 0, 0),
+                mode='constant',
+                value=0.
+            )
+    return ts
+
+
+def pad_results(ys, out_size):
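+    # likewise, append zero columns so every predicted activity matrix has out_size speaker columns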
+    ys_padded = []
+    for i, y in enumerate(ys):
+        if y.shape[1] < out_size:
+            ys_padded.append(
+                torch.cat([y, torch.zeros(y.shape[0], out_size - y.shape[1]).to(torch.float32).to(y.device)], dim=1))
+        else:
+            ys_padded.append(y)
+    return ys_padded
+
+
 class DiarEENDOLAModel(FunASRModel):
     """EEND-OLA diarization model"""
 
     def __init__(
             self,
-            frontend: WavFrontendMel23,
+            frontend: Optional[WavFrontendMel23],
             encoder: EENDOLATransformerEncoder,
             encoder_decoder_attractor: EncoderDecoderAttractor,
             n_units: int = 256,
@@ -47,11 +70,12 @@ class DiarEENDOLAModel(FunASRModel):
             mapping_dict=None,
             **kwargs,
     ):
+        assert check_argument_types()
 
         super().__init__()
         self.frontend = frontend
         self.enc = encoder
-        self.eda = encoder_decoder_attractor
+        self.encoder_decoder_attractor = encoder_decoder_attractor
         self.attractor_loss_weight = attractor_loss_weight
         self.max_n_speaker = max_n_speaker
         if mapping_dict is None:
@@ -74,7 +98,8 @@ class DiarEENDOLAModel(FunASRModel):
     def forward_post_net(self, logits, ilens):
         maxlen = torch.max(ilens).to(torch.int).item()
         logits = nn.utils.rnn.pad_sequence(logits, batch_first=True, padding_value=-1)
-        logits = nn.utils.rnn.pack_padded_sequence(logits, ilens.cpu().to(torch.int64), batch_first=True, enforce_sorted=False)
+        logits = nn.utils.rnn.pack_padded_sequence(logits, ilens.cpu().to(torch.int64), batch_first=True,
+                                                   enforce_sorted=False)
         outputs, (_, _) = self.postnet(logits)
         outputs = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True, padding_value=-1, total_length=maxlen)[0]
         outputs = [output[:ilens[i].to(torch.int).item()] for i, output in enumerate(outputs)]
@@ -83,95 +108,51 @@ class DiarEENDOLAModel(FunASRModel):
 
     def forward(
             self,
-            speech: torch.Tensor,
-            speech_lengths: torch.Tensor,
-            text: torch.Tensor,
-            text_lengths: torch.Tensor,
+            speech: List[torch.Tensor],
+            speech_lengths: torch.Tensor,  # num_frames of each sample
+            speaker_labels: List[torch.Tensor],
+            speaker_labels_lengths: torch.Tensor,  # num_speakers of each sample
+            orders: torch.Tensor,
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
-        """Frontend + Encoder + Decoder + Calc loss
-        Args:
-            speech: (Batch, Length, ...)
-            speech_lengths: (Batch, )
-            text: (Batch, Length)
-            text_lengths: (Batch,)
-        """
-        assert text_lengths.dim() == 1, text_lengths.shape
+
         # Check that batch_size is unified
         assert (
-            speech.shape[0]
-            == speech_lengths.shape[0]
-            == text.shape[0]
-            == text_lengths.shape[0]
-        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
-        batch_size = speech.shape[0]
-
-        # for data-parallel
-        text = text[:, : text_lengths.max()]
-
-        # 1. Encoder
-        encoder_out, encoder_out_lens = self.enc(speech, speech_lengths)
-        intermediate_outs = None
-        if isinstance(encoder_out, tuple):
-            intermediate_outs = encoder_out[1]
-            encoder_out = encoder_out[0]
-
-        loss_att, acc_att, cer_att, wer_att = None, None, None, None
-        loss_ctc, cer_ctc = None, None
-        stats = dict()
+                len(speech)
+                == len(speech_lengths)
+                == len(speaker_labels)
+                == len(speaker_labels_lengths)
+        ), (len(speech), len(speech_lengths), len(speaker_labels), len(speaker_labels_lengths))
+        batch_size = len(speech)
+
+        # Encoder
+        speech = [s[:s_len] for s, s_len in zip(speech, speech_lengths)]
+        encoder_out = self.forward_encoder(speech, speech_lengths)
 
-        # 1. CTC branch
-        if self.ctc_weight != 0.0:
-            loss_ctc, cer_ctc = self._calc_ctc_loss(
-                encoder_out, encoder_out_lens, text, text_lengths
-            )
+        # Encoder-decoder attractor
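+        # frame embeddings are fed to the EDA module in the frame order given by `orders`;
+        # it returns an attractor loss and one attractor per speaker, and the per-frame
+        # speaker logits are dot products between frame embeddings and attractors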
+        attractor_loss, attractors = self.encoder_decoder_attractor([e[order] for e, order in zip(encoder_out, orders)],
+                                                                     speaker_labels_lengths)
+        speaker_logits = [torch.matmul(e, att.permute(1, 0)) for e, att in zip(encoder_out, attractors)]
 
-            # Collect CTC branch stats
-            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
-            stats["cer_ctc"] = cer_ctc
-
-        # Intermediate CTC (optional)
-        loss_interctc = 0.0
-        if self.interctc_weight != 0.0 and intermediate_outs is not None:
-            for layer_idx, intermediate_out in intermediate_outs:
-                # we assume intermediate_out has the same length & padding
-                # as those of encoder_out
-                loss_ic, cer_ic = self._calc_ctc_loss(
-                    intermediate_out, encoder_out_lens, text, text_lengths
-                )
-                loss_interctc = loss_interctc + loss_ic
-
-                # Collect Intermedaite CTC stats
-                stats["loss_interctc_layer{}".format(layer_idx)] = (
-                    loss_ic.detach() if loss_ic is not None else None
-                )
-                stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
-
-            loss_interctc = loss_interctc / len(intermediate_outs)
-
-            # calculate whole encoder loss
-            loss_ctc = (
-                1 - self.interctc_weight
-            ) * loss_ctc + self.interctc_weight * loss_interctc
-
-        # 2b. Attention decoder branch
-        if self.ctc_weight != 1.0:
-            loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
-                encoder_out, encoder_out_lens, text, text_lengths
-            )
+        # pit loss
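+        # permutation-invariant training: for each utterance, the reference labels are permuted
+        # to the speaker order that best matches the logits before the loss is computed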
+        pit_speaker_labels = fast_batch_pit_n_speaker_loss(speaker_logits, speaker_labels)
+        pit_loss = standard_loss(speaker_logits, pit_speaker_labels)
 
-        # 3. CTC-Att loss definition
-        if self.ctc_weight == 0.0:
-            loss = loss_att
-        elif self.ctc_weight == 1.0:
-            loss = loss_ctc
-        else:
-            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
+        # pse loss
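+        # power-set encoding: each frame's multi-speaker label vector is mapped to a single
+        # power-set class via mapping_dict, the attractors are zero-padded to max_n_speaker,
+        # and the recurrent post-net is trained to predict that class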
+        with torch.no_grad():
+            power_ts = [create_powerlabel(label.cpu().numpy(), self.mapping_dict, self.max_n_speaker).
+                        to(encoder_out[0].device, non_blocking=True) for label in pit_speaker_labels]
+        pad_attractors = [pad_attractor(att, self.max_n_speaker) for att in attractors]
+        pse_speaker_logits = [torch.matmul(e, pad_att.permute(1, 0)) for e, pad_att in zip(encoder_out, pad_attractors)]
+        pse_speaker_logits = self.forward_post_net(pse_speaker_logits, speech_lengths)
+        pse_loss = cal_power_loss(pse_speaker_logits, power_ts)
 
-        # Collect Attn branch stats
-        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
-        stats["acc"] = acc_att
-        stats["cer"] = cer_att
-        stats["wer"] = wer_att
+        loss = pse_loss + pit_loss + self.attractor_loss_weight * attractor_loss
+
+        stats = dict()
+        stats["pse_loss"] = pse_loss.detach()
+        stats["pit_loss"] = pit_loss.detach()
+        stats["attractor_loss"] = attractor_loss.detach()
+        stats["batch_size"] = batch_size
 
         # Collect total loss stats
         stats["loss"] = torch.clone(loss.detach())
@@ -193,10 +174,10 @@ class DiarEENDOLAModel(FunASRModel):
             orders = [np.arange(e.shape[0]) for e in emb]
             for order in orders:
                 np.random.shuffle(order)
-            attractors, probs = self.eda.estimate(
+            attractors, probs = self.encoder_decoder_attractor.estimate(
                 [e[torch.from_numpy(order).to(torch.long).to(speech[0].device)] for e, order in zip(emb, orders)])
         else:
-            attractors, probs = self.eda.estimate(emb)
+            attractors, probs = self.encoder_decoder_attractor.estimate(emb)
         attractors_active = []
         for p, att, e in zip(probs, attractors, emb):
             if n_speakers and n_speakers >= 0: