| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322 |
- from contextlib import contextmanager
- from distutils.version import LooseVersion
- from typing import Dict
- from typing import List
- from typing import Optional
- from typing import Tuple
- from typing import Union
- import logging
- import torch
- from funasr.metrics import ErrorCalculator
- from funasr.metrics.compute_acc import th_accuracy
- from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
- from funasr.losses.label_smoothing_loss import (
- LabelSmoothingLoss, # noqa: H301
- )
- from funasr.models.ctc import CTC
- from funasr.models.decoder.abs_decoder import AbsDecoder
- from funasr.models.encoder.abs_encoder import AbsEncoder
- from funasr.frontends.abs_frontend import AbsFrontend
- from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
- from funasr.models.specaug.abs_specaug import AbsSpecAug
- from funasr.layers.abs_normalize import AbsNormalize
- from funasr.train_utils.device_funcs import force_gatherable
- from funasr.models.base_model import FunASRModel
# Feature-detect autocast instead of parsing ``torch.__version__``: the import
# succeeds exactly on the torch releases that ship ``torch.cuda.amp.autocast``,
# and this avoids version-string comparison via the deprecated
# ``distutils.version.LooseVersion``.
try:
    from torch.cuda.amp import autocast
except ImportError:
    # Nothing to do if torch<1.6.0
    @contextmanager
    def autocast(enabled=True):
        """No-op stand-in for ``torch.cuda.amp.autocast``.

        Accepts and ignores ``enabled`` so call sites such as
        ``with autocast(False):`` work unchanged when the real context
        manager is unavailable.
        """
        yield
- import pdb
- import random
- import math
class MFCCA(FunASRModel):
    """Joint CTC / attention ASR model operating on (optionally) multi-channel input.

    Author: Audio, Speech and Language Processing Group (ASLP@NPU), Northwestern Polytechnical University
    MFCCA:Multi-Frame Cross-Channel attention for multi-speaker ASR in Multi-party meeting scenario
    https://arxiv.org/abs/2210.05265
    """

    def __init__(
        self,
        vocab_size: int,
        token_list: Union[Tuple[str, ...], List[str]],
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        encoder: AbsEncoder,
        decoder: AbsDecoder,
        ctc: CTC,
        rnnt_decoder: None = None,
        ctc_weight: float = 0.5,
        ignore_id: int = -1,
        lsm_weight: float = 0.0,
        mask_ratio: float = 0.0,
        length_normalized_loss: bool = False,
        report_cer: bool = True,
        report_wer: bool = True,
        sym_space: str = "<space>",
        sym_blank: str = "<blank>",
        preencoder: Optional[AbsPreEncoder] = None,
    ):
        """Store the sub-modules and build the attention loss / error calculator.

        Args:
            vocab_size: Number of output tokens. The last index
                (``vocab_size - 1``) is used as both <sos> and <eos>.
            token_list: Token strings; a private list copy is kept on the
                instance and the original is handed to ``ErrorCalculator``.
            frontend: Optional feature extractor. When given it is called as
                ``frontend(speech, speech_lengths)`` and must return
                ``(feats, feats_lengths, channel_size)``.
            specaug: Optional augmentation, applied only while training.
            normalize: Optional feature normalization.
            encoder: Called as ``encoder(feats, feats_lengths, channel_size)``.
            decoder: Attention decoder; dropped when ``ctc_weight == 1.0``.
            ctc: CTC module; dropped when ``ctc_weight == 0.0``.
            rnnt_decoder: Must be ``None`` (RNN-T is not implemented).
            ctc_weight: Interpolation weight in ``[0, 1]`` between the CTC
                loss and the attention loss.
            ignore_id: Padding label passed to the loss and accuracy helpers.
            lsm_weight: Label-smoothing value passed to ``LabelSmoothingLoss``.
            mask_ratio: Probability with which ``forward`` keeps only a random
                subset of channels of an 8-channel input; ``0`` disables it.
            length_normalized_loss: Passed to ``LabelSmoothingLoss`` as
                ``normalize_length``.
            report_cer: Passed to ``ErrorCalculator``.
            report_wer: Passed to ``ErrorCalculator``. If both ``report_cer``
                and ``report_wer`` are false no error calculator is created.
            sym_space: Space symbol passed to ``ErrorCalculator``.
            sym_blank: Blank symbol passed to ``ErrorCalculator``.
            preencoder: Optional module applied to the features right before
                the encoder.
        """
        assert 0.0 <= ctc_weight <= 1.0, ctc_weight
        assert rnnt_decoder is None, "Not implemented"

        super().__init__()
        # note that eos is the same as sos (equivalent ID)
        self.sos = vocab_size - 1
        self.eos = vocab_size - 1
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        self.ctc_weight = ctc_weight
        # ``list(...)`` rather than ``.copy()``: the annotation allows a tuple,
        # and tuples have no ``copy`` method.
        self.token_list = list(token_list)
        self.mask_ratio = mask_ratio

        self.frontend = frontend
        self.specaug = specaug
        self.normalize = normalize
        self.preencoder = preencoder
        self.encoder = encoder

        # we set self.decoder = None in the CTC mode since
        # self.decoder parameters were never used and PyTorch complained
        # and threw an Exception in the multi-GPU experiment.
        # thanks Jeff Farris for pointing out the issue.
        self.decoder = None if ctc_weight == 1.0 else decoder
        # Symmetrically, the CTC module is unused in pure-attention mode.
        self.ctc = None if ctc_weight == 0.0 else ctc
        self.rnnt_decoder = rnnt_decoder

        self.criterion_att = LabelSmoothingLoss(
            size=vocab_size,
            padding_idx=ignore_id,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )

        if report_cer or report_wer:
            self.error_calculator = ErrorCalculator(
                token_list, sym_space, sym_blank, report_cer, report_wer
            )
        else:
            self.error_calculator = None

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)

        Returns:
            ``(loss, stats, weight)`` as produced by ``force_gatherable``;
            ``stats`` holds the detached losses plus accuracy / CER / WER
            entries (``None`` for branches that were not computed) and
            ``weight`` is derived from the batch size.
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)

        # Random channel selection: only applied to 3-D input whose last
        # axis has exactly 8 entries, and only when mask_ratio is non-zero.
        num_channels = 8
        if (
            speech.dim() == 3
            and speech.size(2) == num_channels
            and self.mask_ratio != 0
            and random.random() <= self.mask_ratio
        ):
            retain_channel = math.ceil(random.random() * num_channels)
            if retain_channel > 1:
                # Keep a random subset of channels, in their original order.
                kept = torch.randperm(num_channels)[0:retain_channel].sort().values
                speech = speech[:, :, kept]
            else:
                # Indexing with a scalar drops the last axis, so the
                # result is 2-D.
                speech = speech[:, :, torch.randperm(num_channels)[0]]

        batch_size = speech.shape[0]

        # for data-parallel
        text = text[:, : text_lengths.max()]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        # 2a. Attention-decoder branch
        if self.ctc_weight == 1.0:
            loss_att, acc_att, cer_att, wer_att = None, None, None, None
        else:
            loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )

        # 2b. CTC branch
        if self.ctc_weight == 0.0:
            loss_ctc, cer_ctc = None, None
        else:
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )

        # 2c. RNN-T branch
        if self.rnnt_decoder is not None:
            _ = self._calc_rnnt_loss(encoder_out, encoder_out_lens, text, text_lengths)

        # 3. Combine the branch losses according to ctc_weight
        if self.ctc_weight == 0.0:
            loss = loss_att
        elif self.ctc_weight == 1.0:
            loss = loss_ctc
        else:
            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att

        stats = dict(
            loss=loss.detach(),
            loss_att=loss_att.detach() if loss_att is not None else None,
            loss_ctc=loss_ctc.detach() if loss_ctc is not None else None,
            acc=acc_att,
            cer=cer_att,
            wer=wer_att,
            cer_ctc=cer_ctc,
        )

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def collect_feats(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """Run feature extraction only and return the features with their lengths.

        ``text`` and ``text_lengths`` are accepted but not used.
        """
        feats, feats_lengths, _ = self._extract_feats(speech, speech_lengths)
        return {"feats": feats, "feats_lengths": feats_lengths}

    def encode(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder. Note that this method is used by asr_inference.py

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )

        Returns:
            ``(encoder_out, encoder_out_lens)``; the encoder's third return
            value is discarded.
        """
        with autocast(False):
            # 1. Extract feats
            feats, feats_lengths, channel_size = self._extract_feats(
                speech, speech_lengths
            )

            # 2. Data augmentation
            if self.specaug is not None and self.training:
                feats, feats_lengths = self.specaug(feats, feats_lengths)

            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                feats, feats_lengths = self.normalize(feats, feats_lengths)

        # Pre-encoder, e.g. used for raw input data
        if self.preencoder is not None:
            feats, feats_lengths = self.preencoder(feats, feats_lengths)

        # 4. Forward encoder; channel_size is forwarded from the frontend
        encoder_out, encoder_out_lens, _ = self.encoder(
            feats, feats_lengths, channel_size
        )

        assert encoder_out.size(0) == speech.size(0), (
            encoder_out.size(),
            speech.size(0),
        )
        # A 4-D encoder output carries one extra axis before the time axis,
        # so the length check is made against dim 2 instead of dim 1.
        time_axis = 2 if encoder_out.dim() == 4 else 1
        assert encoder_out.size(time_axis) <= encoder_out_lens.max(), (
            encoder_out.size(),
            encoder_out_lens.max(),
        )

        return encoder_out, encoder_out_lens

    def _extract_feats(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, int]:
        """Trim the input to the longest utterance and apply the frontend, if any.

        Returns:
            ``(feats, feats_lengths, channel_size)``. Without a frontend the
            input is passed through unchanged and ``channel_size`` is 1;
            otherwise all three values come from the frontend.
        """
        assert speech_lengths.dim() == 1, speech_lengths.shape

        # for data-parallel
        speech = speech[:, : speech_lengths.max()]

        if self.frontend is not None:
            # Frontend
            #  e.g. STFT and Feature extract
            #       data_loader may send time-domain signal in this case
            # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
            feats, feats_lengths, channel_size = self.frontend(speech, speech_lengths)
        else:
            # No frontend and no feature extract
            feats, feats_lengths = speech, speech_lengths
            channel_size = 1
        return feats, feats_lengths, channel_size

    def _calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        """Compute the attention-decoder loss, token accuracy and (in eval) CER/WER.

        Returns:
            ``(loss_att, acc_att, cer_att, wer_att)``; the last two are
            ``None`` while training or when no error calculator is configured.
        """
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        # presumably one longer to account for the <sos> added by add_sos_eos
        ys_in_lens = ys_pad_lens + 1

        # 1. Forward decoder
        decoder_out, _ = self.decoder(
            encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
        )

        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_out_pad)
        acc_att = th_accuracy(
            decoder_out.view(-1, self.vocab_size),
            ys_out_pad,
            ignore_label=self.ignore_id,
        )

        # Compute cer/wer using attention-decoder
        if self.training or self.error_calculator is None:
            cer_att, wer_att = None, None
        else:
            ys_hat = decoder_out.argmax(dim=-1)
            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

        return loss_att, acc_att, cer_att, wer_att

    def _calc_ctc_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        """Compute the CTC loss and, in eval mode, the CTC-based CER.

        Returns:
            ``(loss_ctc, cer_ctc)``; ``cer_ctc`` is ``None`` while training or
            when no error calculator is configured.
        """
        # A 4-D encoder output is averaged over dim 1 before CTC
        # (presumably the channel axis -- TODO confirm against the encoder).
        if encoder_out.dim() == 4:
            encoder_out = encoder_out.mean(1)

        # Calc CTC loss
        loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)

        # Calc CER using CTC
        cer_ctc = None
        if not self.training and self.error_calculator is not None:
            ys_hat = self.ctc.argmax(encoder_out).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
        return loss_ctc, cer_ctc

    def _calc_rnnt_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        """Placeholder for an RNN-T loss; always raises ``NotImplementedError``."""
        raise NotImplementedError
|