|
|
@@ -1,38 +1,24 @@
|
|
|
# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
|
|
|
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
|
|
|
|
|
-import logging
|
|
|
-import torch
|
|
|
from contextlib import contextmanager
|
|
|
from distutils.version import LooseVersion
|
|
|
-from funasr.layers.abs_normalize import AbsNormalize
|
|
|
-from funasr.losses.label_smoothing_loss import (
|
|
|
- LabelSmoothingLoss, # noqa: H301
|
|
|
-)
|
|
|
-from funasr.models.ctc import CTC
|
|
|
-from funasr.models.decoder.abs_decoder import AbsDecoder
|
|
|
-from funasr.models.encoder.abs_encoder import AbsEncoder
|
|
|
-from funasr.models.frontend.abs_frontend import AbsFrontend
|
|
|
-from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
|
|
|
-from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
|
|
|
-from funasr.models.specaug.abs_specaug import AbsSpecAug
|
|
|
-from funasr.modules.add_sos_eos import add_sos_eos
|
|
|
-from funasr.modules.e2e_asr_common import ErrorCalculator
|
|
|
+from typing import Dict
|
|
|
+from typing import Tuple
|
|
|
+
|
|
|
+import numpy as np
|
|
|
+import torch
|
|
|
+import torch.nn as nn
|
|
|
+from typeguard import check_argument_types
|
|
|
+
|
|
|
from funasr.modules.eend_ola.encoder import TransformerEncoder
|
|
|
from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
|
|
|
from funasr.modules.eend_ola.utils.power import generate_mapping_dict
|
|
|
-from funasr.modules.nets_utils import th_accuracy
|
|
|
from funasr.torch_utils.device_funcs import force_gatherable
|
|
|
from funasr.train.abs_espnet_model import AbsESPnetModel
|
|
|
-from typeguard import check_argument_types
|
|
|
-from typing import Dict
|
|
|
-from typing import List
|
|
|
-from typing import Optional
|
|
|
-from typing import Tuple
|
|
|
-from typing import Union
|
|
|
|
|
|
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
|
|
|
- from torch.cuda.amp import autocast
|
|
|
+ pass
|
|
|
else:
|
|
|
# Nothing to do if torch<1.6.0
|
|
|
@contextmanager
|
|
|
@@ -47,6 +33,7 @@ class DiarEENDOLAModel(AbsESPnetModel):
|
|
|
self,
|
|
|
encoder: TransformerEncoder,
|
|
|
eda: EncoderDecoderAttractor,
|
|
|
+ n_units: int = 256,
|
|
|
max_n_speaker: int = 8,
|
|
|
attractor_loss_weight: float = 1.0,
|
|
|
mapping_dict=None,
|
|
|
@@ -62,6 +49,9 @@ class DiarEENDOLAModel(AbsESPnetModel):
|
|
|
if mapping_dict is None:
|
|
|
mapping_dict = generate_mapping_dict(max_speaker_num=self.max_n_speaker)
|
|
|
self.mapping_dict = mapping_dict
|
|
|
+ # PostNet
|
|
|
+ self.PostNet = nn.LSTM(self.max_n_speaker, n_units, 1, batch_first=True)
|
|
|
+ self.output_layer = nn.Linear(n_units, mapping_dict['oov'] + 1)
|
|
|
|
|
|
def forward(
|
|
|
self,
|
|
|
@@ -163,233 +153,65 @@ class DiarEENDOLAModel(AbsESPnetModel):
|
|
|
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
|
|
|
return loss, stats, weight
|
|
|
|
|
|
- def collect_feats(
|
|
|
- self,
|
|
|
- speech: torch.Tensor,
|
|
|
- speech_lengths: torch.Tensor,
|
|
|
- text: torch.Tensor,
|
|
|
- text_lengths: torch.Tensor,
|
|
|
- ) -> Dict[str, torch.Tensor]:
|
|
|
- if self.extract_feats_in_collect_stats:
|
|
|
- feats, feats_lengths = self._extract_feats(speech, speech_lengths)
|
|
|
+ def estimate_sequential(self,
|
|
|
+ speech: torch.Tensor,
|
|
|
+ speech_lengths: torch.Tensor,
|
|
|
+ n_speakers: int,
|
|
|
+ shuffle: bool,
|
|
|
+ threshold: float,
|
|
|
+ **kwargs):
|
|
|
+ speech = [s[:s_len] for s, s_len in zip(speech, speech_lengths)]
|
|
|
+ emb = self.forward_core(speech) # list, [(T1, C1), ..., (T1, C1)]
|
|
|
+ if shuffle:
|
|
|
+ orders = [np.arange(e.shape[0]) for e in emb]
|
|
|
+ for order in orders:
|
|
|
+ np.random.shuffle(order)
|
|
|
+ # e[order]: shuffle后的embeddings, list, [(T1, C1), ..., (T1, C1)] 每个sample的T维度已进行随机顺序交换
|
|
|
+ # attractors, list, hts(论文里的as), [(max_n_speakers, n_units), ..., (max_n_speakers, n_units)]
|
|
|
+ # probs, list, [(max_n_speakers, ), ..., (max_n_speakers, ]
|
|
|
+ attractors, probs = self.eda.estimate(
|
|
|
+ [e[torch.from_numpy(order).to(torch.long).to(xs[0].device)] for e, order in zip(emb, orders)])
|
|
|
else:
|
|
|
- # Generate dummy stats if extract_feats_in_collect_stats is False
|
|
|
- logging.warning(
|
|
|
- "Generating dummy stats for feats and feats_lengths, "
|
|
|
- "because encoder_conf.extract_feats_in_collect_stats is "
|
|
|
- f"{self.extract_feats_in_collect_stats}"
|
|
|
- )
|
|
|
- feats, feats_lengths = speech, speech_lengths
|
|
|
- return {"feats": feats, "feats_lengths": feats_lengths}
|
|
|
-
|
|
|
- def encode(
|
|
|
- self, speech: torch.Tensor, speech_lengths: torch.Tensor
|
|
|
- ) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
|
- """Frontend + Encoder. Note that this method is used by asr_inference.py
|
|
|
-
|
|
|
- Args:
|
|
|
- speech: (Batch, Length, ...)
|
|
|
- speech_lengths: (Batch, )
|
|
|
- """
|
|
|
- with autocast(False):
|
|
|
- # 1. Extract feats
|
|
|
- feats, feats_lengths = self._extract_feats(speech, speech_lengths)
|
|
|
-
|
|
|
- # 2. Data augmentation
|
|
|
- if self.specaug is not None and self.training:
|
|
|
- feats, feats_lengths = self.specaug(feats, feats_lengths)
|
|
|
-
|
|
|
- # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
|
|
|
- if self.normalize is not None:
|
|
|
- feats, feats_lengths = self.normalize(feats, feats_lengths)
|
|
|
-
|
|
|
- # Pre-encoder, e.g. used for raw input data
|
|
|
- if self.preencoder is not None:
|
|
|
- feats, feats_lengths = self.preencoder(feats, feats_lengths)
|
|
|
-
|
|
|
- # 4. Forward encoder
|
|
|
- # feats: (Batch, Length, Dim)
|
|
|
- # -> encoder_out: (Batch, Length2, Dim2)
|
|
|
- if self.encoder.interctc_use_conditioning:
|
|
|
- encoder_out, encoder_out_lens, _ = self.encoder(
|
|
|
- feats, feats_lengths, ctc=self.ctc
|
|
|
- )
|
|
|
- else:
|
|
|
- encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
|
|
|
- intermediate_outs = None
|
|
|
- if isinstance(encoder_out, tuple):
|
|
|
- intermediate_outs = encoder_out[1]
|
|
|
- encoder_out = encoder_out[0]
|
|
|
-
|
|
|
- # Post-encoder, e.g. NLU
|
|
|
- if self.postencoder is not None:
|
|
|
- encoder_out, encoder_out_lens = self.postencoder(
|
|
|
- encoder_out, encoder_out_lens
|
|
|
- )
|
|
|
-
|
|
|
- assert encoder_out.size(0) == speech.size(0), (
|
|
|
- encoder_out.size(),
|
|
|
- speech.size(0),
|
|
|
- )
|
|
|
- assert encoder_out.size(1) <= encoder_out_lens.max(), (
|
|
|
- encoder_out.size(),
|
|
|
- encoder_out_lens.max(),
|
|
|
- )
|
|
|
-
|
|
|
- if intermediate_outs is not None:
|
|
|
- return (encoder_out, intermediate_outs), encoder_out_lens
|
|
|
-
|
|
|
- return encoder_out, encoder_out_lens
|
|
|
-
|
|
|
- def _extract_feats(
|
|
|
- self, speech: torch.Tensor, speech_lengths: torch.Tensor
|
|
|
- ) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
|
- assert speech_lengths.dim() == 1, speech_lengths.shape
|
|
|
-
|
|
|
- # for data-parallel
|
|
|
- speech = speech[:, : speech_lengths.max()]
|
|
|
-
|
|
|
- if self.frontend is not None:
|
|
|
- # Frontend
|
|
|
- # e.g. STFT and Feature extract
|
|
|
- # data_loader may send time-domain signal in this case
|
|
|
- # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
|
|
|
- feats, feats_lengths = self.frontend(speech, speech_lengths)
|
|
|
- else:
|
|
|
- # No frontend and no feature extract
|
|
|
- feats, feats_lengths = speech, speech_lengths
|
|
|
- return feats, feats_lengths
|
|
|
-
|
|
|
- def nll(
|
|
|
- self,
|
|
|
- encoder_out: torch.Tensor,
|
|
|
- encoder_out_lens: torch.Tensor,
|
|
|
- ys_pad: torch.Tensor,
|
|
|
- ys_pad_lens: torch.Tensor,
|
|
|
- ) -> torch.Tensor:
|
|
|
- """Compute negative log likelihood(nll) from transformer-decoder
|
|
|
-
|
|
|
- Normally, this function is called in batchify_nll.
|
|
|
-
|
|
|
- Args:
|
|
|
- encoder_out: (Batch, Length, Dim)
|
|
|
- encoder_out_lens: (Batch,)
|
|
|
- ys_pad: (Batch, Length)
|
|
|
- ys_pad_lens: (Batch,)
|
|
|
- """
|
|
|
- ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
|
|
|
- ys_in_lens = ys_pad_lens + 1
|
|
|
-
|
|
|
- # 1. Forward decoder
|
|
|
- decoder_out, _ = self.decoder(
|
|
|
- encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
|
|
|
- ) # [batch, seqlen, dim]
|
|
|
- batch_size = decoder_out.size(0)
|
|
|
- decoder_num_class = decoder_out.size(2)
|
|
|
- # nll: negative log-likelihood
|
|
|
- nll = torch.nn.functional.cross_entropy(
|
|
|
- decoder_out.view(-1, decoder_num_class),
|
|
|
- ys_out_pad.view(-1),
|
|
|
- ignore_index=self.ignore_id,
|
|
|
- reduction="none",
|
|
|
- )
|
|
|
- nll = nll.view(batch_size, -1)
|
|
|
- nll = nll.sum(dim=1)
|
|
|
- assert nll.size(0) == batch_size
|
|
|
- return nll
|
|
|
-
|
|
|
- def batchify_nll(
|
|
|
- self,
|
|
|
- encoder_out: torch.Tensor,
|
|
|
- encoder_out_lens: torch.Tensor,
|
|
|
- ys_pad: torch.Tensor,
|
|
|
- ys_pad_lens: torch.Tensor,
|
|
|
- batch_size: int = 100,
|
|
|
- ):
|
|
|
- """Compute negative log likelihood(nll) from transformer-decoder
|
|
|
-
|
|
|
- To avoid OOM, this fuction seperate the input into batches.
|
|
|
- Then call nll for each batch and combine and return results.
|
|
|
- Args:
|
|
|
- encoder_out: (Batch, Length, Dim)
|
|
|
- encoder_out_lens: (Batch,)
|
|
|
- ys_pad: (Batch, Length)
|
|
|
- ys_pad_lens: (Batch,)
|
|
|
- batch_size: int, samples each batch contain when computing nll,
|
|
|
- you may change this to avoid OOM or increase
|
|
|
- GPU memory usage
|
|
|
- """
|
|
|
- total_num = encoder_out.size(0)
|
|
|
- if total_num <= batch_size:
|
|
|
- nll = self.nll(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
|
|
|
- else:
|
|
|
- nll = []
|
|
|
- start_idx = 0
|
|
|
- while True:
|
|
|
- end_idx = min(start_idx + batch_size, total_num)
|
|
|
- batch_encoder_out = encoder_out[start_idx:end_idx, :, :]
|
|
|
- batch_encoder_out_lens = encoder_out_lens[start_idx:end_idx]
|
|
|
- batch_ys_pad = ys_pad[start_idx:end_idx, :]
|
|
|
- batch_ys_pad_lens = ys_pad_lens[start_idx:end_idx]
|
|
|
- batch_nll = self.nll(
|
|
|
- batch_encoder_out,
|
|
|
- batch_encoder_out_lens,
|
|
|
- batch_ys_pad,
|
|
|
- batch_ys_pad_lens,
|
|
|
- )
|
|
|
- nll.append(batch_nll)
|
|
|
- start_idx = end_idx
|
|
|
- if start_idx == total_num:
|
|
|
- break
|
|
|
- nll = torch.cat(nll)
|
|
|
- assert nll.size(0) == total_num
|
|
|
- return nll
|
|
|
-
|
|
|
- def _calc_att_loss(
|
|
|
- self,
|
|
|
- encoder_out: torch.Tensor,
|
|
|
- encoder_out_lens: torch.Tensor,
|
|
|
- ys_pad: torch.Tensor,
|
|
|
- ys_pad_lens: torch.Tensor,
|
|
|
- ):
|
|
|
- ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
|
|
|
- ys_in_lens = ys_pad_lens + 1
|
|
|
-
|
|
|
- # 1. Forward decoder
|
|
|
- decoder_out, _ = self.decoder(
|
|
|
- encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
|
|
|
- )
|
|
|
-
|
|
|
- # 2. Compute attention loss
|
|
|
- loss_att = self.criterion_att(decoder_out, ys_out_pad)
|
|
|
- acc_att = th_accuracy(
|
|
|
- decoder_out.view(-1, self.vocab_size),
|
|
|
- ys_out_pad,
|
|
|
- ignore_label=self.ignore_id,
|
|
|
- )
|
|
|
-
|
|
|
- # Compute cer/wer using attention-decoder
|
|
|
- if self.training or self.error_calculator is None:
|
|
|
- cer_att, wer_att = None, None
|
|
|
- else:
|
|
|
- ys_hat = decoder_out.argmax(dim=-1)
|
|
|
- cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
|
|
|
-
|
|
|
- return loss_att, acc_att, cer_att, wer_att
|
|
|
-
|
|
|
- def _calc_ctc_loss(
|
|
|
- self,
|
|
|
- encoder_out: torch.Tensor,
|
|
|
- encoder_out_lens: torch.Tensor,
|
|
|
- ys_pad: torch.Tensor,
|
|
|
- ys_pad_lens: torch.Tensor,
|
|
|
- ):
|
|
|
- # Calc CTC loss
|
|
|
- loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
|
|
|
-
|
|
|
- # Calc CER using CTC
|
|
|
- cer_ctc = None
|
|
|
- if not self.training and self.error_calculator is not None:
|
|
|
- ys_hat = self.ctc.argmax(encoder_out).data
|
|
|
- cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
|
|
|
- return loss_ctc, cer_ctc
|
|
|
+ attractors, probs = self.eda.estimate(emb)
|
|
|
+ attractors_active = []
|
|
|
+ for p, att, e in zip(probs, attractors, emb):
|
|
|
+ if n_speakers and n_speakers >= 0: # 根据指定说话人数, 选择对应数量的ys
|
|
|
+ # TODO:在测试有不同数量speaker数的数据集时,考虑改成根据sample来确定具体的speaker数,而不是直接指定
|
|
|
+ # raise NotImplementedError
|
|
|
+ att = att[:n_speakers, ]
|
|
|
+ attractors_active.append(att)
|
|
|
+ elif threshold is not None:
|
|
|
+ silence = torch.nonzero(p < threshold)[0] # 找到第一个输出概率小于阈值的索引, 作为结束, 且值刚好等于说话人数
|
|
|
+ n_spk = silence[0] if silence.size else None
|
|
|
+ att = att[:n_spk, ]
|
|
|
+ attractors_active.append(att)
|
|
|
+ else:
|
|
|
+ NotImplementedError('n_speakers or th has to be given.')
|
|
|
+ raw_n_speakers = [att.shape[0] for att in attractors_active] # [C1, C2, ..., CB]
|
|
|
+ attractors = [
|
|
|
+ pad_attractor(att, self.max_n_speaker) if att.shape[0] <= self.max_n_speaker else att[:self.max_n_speaker]
|
|
|
+ for att in attractors_active]
|
|
|
+ ys = [torch.matmul(e, att.permute(1, 0)) for e, att in zip(emb, attractors)]
|
|
|
+ # ys_eda = [torch.sigmoid(y[:, :n_spk]) for y,n_spk in zip(ys, raw_n_speakers)]
|
|
|
+ logits = self.cal_postnet(ys, self.max_n_speaker)
|
|
|
+ ys = [self.recover_y_from_powerlabel(logit, raw_n_speaker) for logit, raw_n_speaker in
|
|
|
+ zip(logits, raw_n_speakers)]
|
|
|
+
|
|
|
+ return ys, emb, attractors, raw_n_speakers
|
|
|
+
|
|
|
+ def recover_y_from_powerlabel(self, logit, n_speaker):
|
|
|
+ pred = torch.argmax(torch.softmax(logit, dim=-1), dim=-1) # (T, )
|
|
|
+ oov_index = torch.where(pred == self.mapping_dict['oov'])[0]
|
|
|
+ for i in oov_index:
|
|
|
+ if i > 0:
|
|
|
+ pred[i] = pred[i - 1]
|
|
|
+ else:
|
|
|
+ pred[i] = 0
|
|
|
+ pred = [self.reporter.inv_mapping_func(i, self.mapping_dict) for i in pred]
|
|
|
+ # print(pred)
|
|
|
+ decisions = [bin(num)[2:].zfill(self.max_n_speaker)[::-1] for num in pred]
|
|
|
+ decisions = torch.from_numpy(
|
|
|
+ np.stack([np.array([int(i) for i in dec]) for dec in decisions], axis=0)).to(logit.device).to(
|
|
|
+ torch.float32)
|
|
|
+ decisions = decisions[:, :n_speaker]
|
|
|
+ return decisions
|