
boundary aware transducer (#691)

* boundary aware transducer

* resolve conflict

* delete type check

---------

Co-authored-by: aky15 <ankeyu.aky@11.17.44.249>
aky15 2 years ago
Parent
Commit
05ada32da8

+ 2 - 0
funasr/bin/asr_inference_launch.py

@@ -1604,6 +1604,8 @@ def inference_launch(**kwargs):
         return inference_mfcca(**kwargs)
     elif mode == "rnnt":
         return inference_transducer(**kwargs)
+    elif mode == "bat":
+        return inference_transducer(**kwargs)
     elif mode == "sa_asr":
         return inference_sa_asr(**kwargs)
     else:
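
With this change, mode="bat" dispatches to the same decoding path as mode="rnnt". A minimal usage sketch (the kwargs shown are illustrative placeholders, not taken from this commit):

    from funasr.bin.asr_inference_launch import inference_launch

    # "bat" reuses inference_transducer(), so a trained BAT model is decoded
    # exactly like an RNN-T model
    inference_pipeline = inference_launch(
        mode="bat",
        asr_train_config="exp/bat/config.yaml",  # placeholder path
        asr_model_file="exp/bat/model.pb",       # placeholder path
        output_dir="exp/bat/decode",
    )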

+ 60 - 8
funasr/build_utils/build_asr_model.py

@@ -26,6 +26,7 @@ from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
 from funasr.models.e2e_asr_mfcca import MFCCA

 from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel
+from funasr.models.e2e_asr_bat import BATModel

 from funasr.models.e2e_sa_asr import SAASRModel
 from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, ParaformerBert, BiCifParaformer, ContextualParaformer
@@ -46,7 +47,7 @@ from funasr.models.frontend.s3prl import S3prlFrontend
 from funasr.models.frontend.wav_frontend import WavFrontend
 from funasr.models.frontend.windowing import SlidingWindow
 from funasr.models.joint_net.joint_network import JointNetwork
-from funasr.models.predictor.cif import CifPredictor, CifPredictorV2, CifPredictorV3
+from funasr.models.predictor.cif import CifPredictor, CifPredictorV2, CifPredictorV3, BATPredictor
 from funasr.models.specaug.specaug import SpecAug
 from funasr.models.specaug.specaug import SpecAugLFR
 from funasr.modules.subsampling import Conv1dSubsampling
@@ -99,7 +100,7 @@ model_choices = ClassChoices(
         rnnt=TransducerModel,
         rnnt_unified=UnifiedTransducerModel,
         sa_asr=SAASRModel,
-
+        bat=BATModel,
     ),
     default="asr",
 )
@@ -188,6 +189,7 @@ predictor_choices = ClassChoices(
         ctc_predictor=None,
         cif_predictor_v2=CifPredictorV2,
         cif_predictor_v3=CifPredictorV3,
+        bat_predictor=BATPredictor,
     ),
     default="cif_predictor",
     optional=True,
@@ -313,12 +315,15 @@ def build_asr_model(args):
     encoder = encoder_class(input_size=input_size, **args.encoder_conf)

     # decoder
-    decoder_class = decoder_choices.get_class(args.decoder)
-    decoder = decoder_class(
-        vocab_size=vocab_size,
-        encoder_output_size=encoder.output_size(),
-        **args.decoder_conf,
-    )
+    if hasattr(args, "decoder") and args.decoder is not None:
+        decoder_class = decoder_choices.get_class(args.decoder)
+        decoder = decoder_class(
+            vocab_size=vocab_size,
+            encoder_output_size=encoder.output_size(),
+            **args.decoder_conf,
+        )
+    else:
+        decoder = None

     # ctc
     ctc = CTC(
@@ -463,6 +468,53 @@ def build_asr_model(args):
             joint_network=joint_network,
             **args.model_conf,
         )
+    elif args.model == "bat":
+        # 5. Decoder
+        encoder_output_size = encoder.output_size()
+
+        rnnt_decoder_class = rnnt_decoder_choices.get_class(args.rnnt_decoder)
+        decoder = rnnt_decoder_class(
+            vocab_size,
+            **args.rnnt_decoder_conf,
+        )
+        decoder_output_size = decoder.output_size
+
+        if getattr(args, "decoder", None) is not None:
+            att_decoder_class = decoder_choices.get_class(args.decoder)
+
+            att_decoder = att_decoder_class(
+                vocab_size=vocab_size,
+                encoder_output_size=encoder_output_size,
+                **args.decoder_conf,
+            )
+        else:
+            att_decoder = None
+        # 6. Joint Network
+        joint_network = JointNetwork(
+            vocab_size,
+            encoder_output_size,
+            decoder_output_size,
+            **args.joint_network_conf,
+        )
+
+        predictor_class = predictor_choices.get_class(args.predictor)
+        predictor = predictor_class(**args.predictor_conf)
+
+        model_class = model_choices.get_class(args.model)
+        # 7. Build model
+        model = model_class(
+            vocab_size=vocab_size,
+            token_list=token_list,
+            frontend=frontend,
+            specaug=specaug,
+            normalize=normalize,
+            encoder=encoder,
+            decoder=decoder,
+            att_decoder=att_decoder,
+            joint_network=joint_network,
+            predictor=predictor,
+            **args.model_conf,
+        )
     elif args.model == "sa_asr":
     elif args.model == "sa_asr":
         asr_encoder_class = asr_encoder_choices.get_class(args.asr_encoder)
         asr_encoder_class = asr_encoder_choices.get_class(args.asr_encoder)
         asr_encoder = asr_encoder_class(input_size=input_size, **args.asr_encoder_conf)
         asr_encoder = asr_encoder_class(input_size=input_size, **args.asr_encoder_conf)
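
The new branch expects transducer-style arguments plus a boundary predictor. A minimal sketch of the fields build_asr_model reads on this path (field names follow the diff above; choice names and values are illustrative):

    from types import SimpleNamespace

    args = SimpleNamespace(
        model="bat",
        encoder="conformer", encoder_conf={},       # any registered encoder
        rnnt_decoder="rnn", rnnt_decoder_conf={},   # prediction network
        decoder=None, decoder_conf={},              # attention decoder is optional
        joint_network_conf={},
        predictor="bat_predictor",
        predictor_conf={"idim": 512, "l_order": 1, "r_order": 1},
        model_conf={"transducer_weight": 1.0, "r_d": 5, "r_u": 5},
        # plus the usual token_list / input_size / frontend / specaug fields
    )
    # model = build_asr_model(args)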

+ 496 - 0
funasr/models/e2e_asr_bat.py

@@ -0,0 +1,496 @@
+"""Boundary Aware Transducer (BAT) model."""
+
+import logging
+from contextlib import contextmanager
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+from packaging.version import parse as V
+from funasr.losses.label_smoothing_loss import (
+    LabelSmoothingLoss,  # noqa: H301
+)
+from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.models.specaug.abs_specaug import AbsSpecAug
+from funasr.models.decoder.rnnt_decoder import RNNTDecoder
+from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder
+from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.models.joint_net.joint_network import JointNetwork
+from funasr.modules.nets_utils import get_transducer_task_io
+from funasr.modules.nets_utils import th_accuracy
+from funasr.modules.nets_utils import make_pad_mask
+from funasr.modules.add_sos_eos import add_sos_eos
+from funasr.layers.abs_normalize import AbsNormalize
+from funasr.torch_utils.device_funcs import force_gatherable
+from funasr.models.base_model import FunASRModel
+
+if V(torch.__version__) >= V("1.6.0"):
+    from torch.cuda.amp import autocast
+else:
+
+    @contextmanager
+    def autocast(enabled=True):
+        yield
+
+
+class BATModel(FunASRModel):
+    """BATModel module definition.
+
+    Args:
+        vocab_size: Size of complete vocabulary (w/ EOS and blank included).
+        token_list: List of tokens.
+        frontend: Frontend module.
+        specaug: SpecAugment module.
+        normalize: Normalization module.
+        encoder: Encoder module.
+        decoder: RNNT decoder (prediction network) module.
+        joint_network: Joint Network module.
+        att_decoder: Optional attention decoder module.
+        predictor: Boundary predictor module (CIF-based).
+        transducer_weight: Weight of the Transducer loss.
+        predictor_weight: Weight of the predictor (token count) L1 loss.
+        cif_weight: Weight of the auxiliary CIF cross-entropy loss.
+        fastemit_lambda: FastEmit lambda value.
+        auxiliary_ctc_weight: Weight of auxiliary CTC loss.
+        auxiliary_ctc_dropout_rate: Dropout rate for auxiliary CTC loss inputs.
+        auxiliary_lm_loss_weight: Weight of auxiliary LM loss.
+        auxiliary_lm_loss_smoothing: Smoothing rate for LM loss' label smoothing.
+        ignore_id: Initial padding ID.
+        sym_space: Space symbol.
+        sym_blank: Blank symbol.
+        report_cer: Whether to report Character Error Rate during validation.
+        report_wer: Whether to report Word Error Rate during validation.
+        extract_feats_in_collect_stats: Whether to use extract_feats stats collection.
+        lsm_weight: Label smoothing weight for the CIF loss.
+        length_normalized_loss: Whether to normalize the CIF loss by target length.
+        r_d: Symbols kept below each predicted boundary in the pruned lattice.
+        r_u: Symbols kept above each predicted boundary in the pruned lattice.
+
+    """
+
+    def __init__(
+        self,
+        vocab_size: int,
+        token_list: Union[Tuple[str, ...], List[str]],
+        frontend: Optional[AbsFrontend],
+        specaug: Optional[AbsSpecAug],
+        normalize: Optional[AbsNormalize],
+        encoder: AbsEncoder,
+        decoder: RNNTDecoder,
+        joint_network: JointNetwork,
+        att_decoder: Optional[AbsAttDecoder] = None,
+        predictor: Optional[torch.nn.Module] = None,
+        transducer_weight: float = 1.0,
+        predictor_weight: float = 1.0,
+        cif_weight: float = 1.0,
+        fastemit_lambda: float = 0.0,
+        auxiliary_ctc_weight: float = 0.0,
+        auxiliary_ctc_dropout_rate: float = 0.0,
+        auxiliary_lm_loss_weight: float = 0.0,
+        auxiliary_lm_loss_smoothing: float = 0.0,
+        ignore_id: int = -1,
+        sym_space: str = "<space>",
+        sym_blank: str = "<blank>",
+        report_cer: bool = True,
+        report_wer: bool = True,
+        extract_feats_in_collect_stats: bool = True,
+        lsm_weight: float = 0.0,
+        length_normalized_loss: bool = False,
+        r_d: int = 5,
+        r_u: int = 5,
+    ) -> None:
+        """Construct an BATModel object."""
+        super().__init__()
+
+        # The following labels ID are reserved: 0 (blank) and vocab_size - 1 (sos/eos)
+        self.blank_id = 0
+        self.vocab_size = vocab_size
+        self.ignore_id = ignore_id
+        self.token_list = token_list.copy()
+
+        self.sym_space = sym_space
+        self.sym_blank = sym_blank
+
+        self.frontend = frontend
+        self.specaug = specaug
+        self.normalize = normalize
+
+        self.encoder = encoder
+        self.decoder = decoder
+        self.joint_network = joint_network
+
+        self.criterion_transducer = None
+        self.error_calculator = None
+
+        self.use_auxiliary_ctc = auxiliary_ctc_weight > 0
+        self.use_auxiliary_lm_loss = auxiliary_lm_loss_weight > 0
+
+        if self.use_auxiliary_ctc:
+            self.ctc_lin = torch.nn.Linear(encoder.output_size(), vocab_size)
+            self.ctc_dropout_rate = auxiliary_ctc_dropout_rate
+
+        if self.use_auxiliary_lm_loss:
+            self.lm_lin = torch.nn.Linear(decoder.output_size, vocab_size)
+            self.lm_loss_smoothing = auxiliary_lm_loss_smoothing
+
+        self.transducer_weight = transducer_weight
+        self.fastemit_lambda = fastemit_lambda
+
+        self.auxiliary_ctc_weight = auxiliary_ctc_weight
+        self.auxiliary_lm_loss_weight = auxiliary_lm_loss_weight
+
+        self.report_cer = report_cer
+        self.report_wer = report_wer
+
+        self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
+
+        self.criterion_pre = torch.nn.L1Loss()
+        self.predictor_weight = predictor_weight
+        self.predictor = predictor
+        
+        self.cif_weight = cif_weight
+        if self.cif_weight > 0:
+            self.cif_output_layer = torch.nn.Linear(encoder.output_size(), vocab_size)
+            self.criterion_cif = LabelSmoothingLoss(
+                size=vocab_size,
+                padding_idx=ignore_id,
+                smoothing=lsm_weight,
+                normalize_length=length_normalized_loss,
+            )        
+        self.r_d = r_d
+        self.r_u = r_u
+
+    def forward(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+        """Forward architecture and compute loss(es).
+
+        Args:
+            speech: Speech sequences. (B, S)
+            speech_lengths: Speech sequences lengths. (B,)
+            text: Label ID sequences. (B, L)
+            text_lengths: Label ID sequences lengths. (B,)
+            kwargs: Contains "utts_id".
+
+        Return:
+            loss: Main loss value.
+            stats: Task statistics.
+            weight: Task weights.
+
+        """
+        assert text_lengths.dim() == 1, text_lengths.shape
+        assert (
+            speech.shape[0]
+            == speech_lengths.shape[0]
+            == text.shape[0]
+            == text_lengths.shape[0]
+        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
+
+        batch_size = speech.shape[0]
+        text = text[:, : text_lengths.max()]
+
+        # 1. Encoder
+        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+        if hasattr(self.encoder, "overlap_chunk_cls") and self.encoder.overlap_chunk_cls is not None:
+            encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(
+                encoder_out, encoder_out_lens, chunk_outs=None
+            )
+
+        encoder_out_mask = (
+            ~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]
+        ).to(encoder_out.device)
+        # 2. Transducer-related I/O preparation
+        decoder_in, target, t_len, u_len = get_transducer_task_io(
+            text,
+            encoder_out_lens,
+            ignore_id=self.ignore_id,
+        )
+
+        # 3. Decoder
+        self.decoder.set_device(encoder_out.device)
+        decoder_out = self.decoder(decoder_in, u_len)
+
+        # 4. Boundary predictor (CIF)
+        pre_acoustic_embeds, pre_token_length, _, pre_peak_index = self.predictor(
+            encoder_out, text, encoder_out_mask, ignore_id=self.ignore_id
+        )
+        loss_pre = self.criterion_pre(text_lengths.type_as(pre_token_length), pre_token_length)
+
+        if self.cif_weight > 0.0:
+            cif_predict = self.cif_output_layer(pre_acoustic_embeds)
+            loss_cif = self.criterion_cif(cif_predict, text)
+        else:
+            loss_cif = 0.0
+
+        # 5. Losses
+        boundary = torch.zeros((encoder_out.size(0), 4), dtype=torch.int64, device=encoder_out.device)
+        boundary[:, 2] = u_len.long().detach()
+        boundary[:, 3] = t_len.long().detach()
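+        # fast_rnnt boundary rows follow [s_begin, t_begin, s_end, t_end];
+        # begins stay 0, ends are the per-utterance label/frame lengths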
+
+        pre_peak_index = torch.floor(pre_peak_index).long()
+        s_begin = pre_peak_index - self.r_d
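+        # the pruning window for frame t starts r_d symbols below the
+        # CIF-predicted boundary and spans r_d + r_u symbols in total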
+
+        T = encoder_out.size(1)
+        B = encoder_out.size(0)
+        U = decoder_out.size(1)
+
+        mask = torch.arange(0, T, device=encoder_out.device).reshape(1, T).expand(B, T)
+        mask = mask <= boundary[:, 3].reshape(B, 1) - 1
+
+        s_begin_padding = boundary[:, 2].reshape(B, 1) - (self.r_u + self.r_d) + 1
+        # handle the cases where `len(symbols) < s_range`
+        s_begin_padding = torch.clamp(s_begin_padding, min=0)
+
+        s_begin = torch.where(mask, s_begin, s_begin_padding)
+
+        mask2 = s_begin < boundary[:, 2].reshape(B, 1) - (self.r_u + self.r_d) + 1
+        s_begin = torch.where(mask2, s_begin, boundary[:, 2].reshape(B, 1) - (self.r_u + self.r_d) + 1)
+        s_begin = torch.clamp(s_begin, min=0)
+
+        s_range = min(self.r_d + self.r_u, int(u_len.min()))
+        ranges = s_begin.reshape((B, T, 1)).expand((B, T, s_range)) + torch.arange(
+            s_range, device=encoder_out.device
+        )
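+        # ranges: (B, T, s_range) — the symbol indices each frame is allowed
+        # to score in the pruned lattice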
+
+        import fast_rnnt  # imported lazily so the dependency stays optional
+        am_pruned, lm_pruned = fast_rnnt.do_rnnt_pruning(
+            am=self.joint_network.lin_enc(encoder_out),
+            lm=self.joint_network.lin_dec(decoder_out),
+            ranges=ranges,
+        )
+
+        logits = self.joint_network(am_pruned, lm_pruned, project_input=False)
+
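+        # the pruned loss is computed in fp32 (note logits.float()); AMP is
+        # disabled here for numerical stability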
+        with torch.cuda.amp.autocast(enabled=False):
+            loss_trans = fast_rnnt.rnnt_loss_pruned(
+                logits=logits.float(),
+                symbols=target.long(),
+                ranges=ranges,
+                termination_symbol=self.blank_id,
+                boundary=boundary,
+                reduction="sum",
+            )
+
+        cer_trans, wer_trans = None, None
+        if not self.training and (self.report_cer or self.report_wer):
+            if self.error_calculator is None:
+                from funasr.modules.e2e_asr_common import ErrorCalculatorTransducer as ErrorCalculator
+                self.error_calculator = ErrorCalculator(
+                    self.decoder,
+                    self.joint_network,
+                    self.token_list,
+                    self.sym_space,
+                    self.sym_blank,
+                    report_cer=self.report_cer,
+                    report_wer=self.report_wer,
+                )
+            cer_trans, wer_trans = self.error_calculator(encoder_out, target, t_len)
+
+        loss_ctc, loss_lm = 0.0, 0.0
+
+        if self.use_auxiliary_ctc:
+            loss_ctc = self._calc_ctc_loss(
+                encoder_out,
+                target,
+                t_len,
+                u_len,
+            )
+
+        if self.use_auxiliary_lm_loss:
+            loss_lm = self._calc_lm_loss(decoder_out, target)
+
+        loss = (
+            self.transducer_weight * loss_trans
+            + self.auxiliary_ctc_weight * loss_ctc
+            + self.auxiliary_lm_loss_weight * loss_lm
+            + self.predictor_weight * loss_pre
+            + self.cif_weight * loss_cif
+        )
+
+        stats = dict(
+            loss=loss.detach(),
+            loss_transducer=loss_trans.detach(),
+            loss_pre=loss_pre.detach(),
+            loss_cif=loss_cif.detach() if loss_cif > 0.0 else None,
+            aux_ctc_loss=loss_ctc.detach() if loss_ctc > 0.0 else None,
+            aux_lm_loss=loss_lm.detach() if loss_lm > 0.0 else None,
+            cer_transducer=cer_trans,
+            wer_transducer=wer_trans,
+        )
+
+        # force_gatherable: to-device and to-tensor if scalar for DataParallel
+        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+
+        return loss, stats, weight
+
+    def collect_feats(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
+        **kwargs,
+    ) -> Dict[str, torch.Tensor]:
+        """Collect features sequences and features lengths sequences.
+
+        Args:
+            speech: Speech sequences. (B, S)
+            speech_lengths: Speech sequences lengths. (B,)
+            text: Label ID sequences. (B, L)
+            text_lengths: Label ID sequences lengths. (B,)
+            kwargs: Contains "utts_id".
+
+        Return:
+            {}: "feats": Features sequences. (B, T, D_feats),
+                "feats_lengths": Features sequences lengths. (B,)
+
+        """
+        if self.extract_feats_in_collect_stats:
+            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
+        else:
+            # Generate dummy stats if extract_feats_in_collect_stats is False
+            logging.warning(
+                "Generating dummy stats for feats and feats_lengths, "
+                "because encoder_conf.extract_feats_in_collect_stats is "
+                f"{self.extract_feats_in_collect_stats}"
+            )
+
+            feats, feats_lengths = speech, speech_lengths
+
+        return {"feats": feats, "feats_lengths": feats_lengths}
+
+    def encode(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Encoder speech sequences.
+
+        Args:
+            speech: Speech sequences. (B, S)
+            speech_lengths: Speech sequences lengths. (B,)
+
+        Return:
+            encoder_out: Encoder outputs. (B, T, D_enc)
+            encoder_out_lens: Encoder outputs lengths. (B,)
+
+        """
+        with autocast(False):
+            # 1. Extract feats
+            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
+
+            # 2. Data augmentation
+            if self.specaug is not None and self.training:
+                feats, feats_lengths = self.specaug(feats, feats_lengths)
+
+            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+            if self.normalize is not None:
+                feats, feats_lengths = self.normalize(feats, feats_lengths)
+
+        # 4. Forward encoder
+        encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
+
+        assert encoder_out.size(0) == speech.size(0), (
+            encoder_out.size(),
+            speech.size(0),
+        )
+        assert encoder_out.size(1) <= encoder_out_lens.max(), (
+            encoder_out.size(),
+            encoder_out_lens.max(),
+        )
+
+        return encoder_out, encoder_out_lens
+
+    def _extract_feats(
+        self, speech: torch.Tensor, speech_lengths: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Extract features sequences and features sequences lengths.
+
+        Args:
+            speech: Speech sequences. (B, S)
+            speech_lengths: Speech sequences lengths. (B,)
+
+        Return:
+            feats: Features sequences. (B, T, D_feats)
+            feats_lengths: Features sequences lengths. (B,)
+
+        """
+        assert speech_lengths.dim() == 1, speech_lengths.shape
+
+        # for data-parallel
+        speech = speech[:, : speech_lengths.max()]
+
+        if self.frontend is not None:
+            feats, feats_lengths = self.frontend(speech, speech_lengths)
+        else:
+            feats, feats_lengths = speech, speech_lengths
+
+        return feats, feats_lengths
+
+    def _calc_ctc_loss(
+        self,
+        encoder_out: torch.Tensor,
+        target: torch.Tensor,
+        t_len: torch.Tensor,
+        u_len: torch.Tensor,
+    ) -> torch.Tensor:
+        """Compute CTC loss.
+
+        Args:
+            encoder_out: Encoder output sequences. (B, T, D_enc)
+            target: Target label ID sequences. (B, L)
+            t_len: Encoder output sequences lengths. (B,)
+            u_len: Target label ID sequences lengths. (B,)
+
+        Return:
+            loss_ctc: CTC loss value.
+
+        """
+        ctc_in = self.ctc_lin(
+            torch.nn.functional.dropout(encoder_out, p=self.ctc_dropout_rate)
+        )
+        ctc_in = torch.log_softmax(ctc_in.transpose(0, 1), dim=-1)
+
+        target_mask = target != 0
+        ctc_target = target[target_mask].cpu()
+
+        with torch.backends.cudnn.flags(deterministic=True):
+            loss_ctc = torch.nn.functional.ctc_loss(
+                ctc_in,
+                ctc_target,
+                t_len,
+                u_len,
+                zero_infinity=True,
+                reduction="sum",
+            )
+        loss_ctc /= target.size(0)
+
+        return loss_ctc
+
+    def _calc_lm_loss(
+        self,
+        decoder_out: torch.Tensor,
+        target: torch.Tensor,
+    ) -> torch.Tensor:
+        """Compute LM loss.
+
+        Args:
+            decoder_out: Decoder output sequences. (B, U, D_dec)
+            target: Target label ID sequences. (B, L)
+
+        Return:
+            loss_lm: LM loss value.
+
+        """
+        lm_loss_in = self.lm_lin(decoder_out[:, :-1, :]).view(-1, self.vocab_size)
+        lm_target = target.view(-1).type(torch.int64)
+
+        with torch.no_grad():
+            true_dist = lm_loss_in.clone()
+            true_dist.fill_(self.lm_loss_smoothing / (self.vocab_size - 1))
+
+            # Ignore blank ID (0)
+            ignore = lm_target == 0
+            lm_target = lm_target.masked_fill(ignore, 0)
+
+            true_dist.scatter_(1, lm_target.unsqueeze(1), (1 - self.lm_loss_smoothing))
+
+        loss_lm = torch.nn.functional.kl_div(
+            torch.log_softmax(lm_loss_in, dim=1),
+            true_dist,
+            reduction="none",
+        )
+        loss_lm = loss_lm.masked_fill(ignore.unsqueeze(1), 0).sum() / decoder_out.size(
+            0
+        )
+
+        return loss_lm
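
The pruning logic in BATModel.forward is easiest to see on toy shapes. A self-contained sketch of the window construction (pure PyTorch, no fast_rnnt; the peak positions and r_d/r_u values are made up for illustration):

    import torch

    B, T, r_d, r_u = 1, 6, 2, 2
    u_len = torch.tensor([6])                   # 6 target symbols
    peaks = torch.tensor([[0, 1, 2, 3, 4, 5]])  # CIF boundary per frame, (B, T)

    s_begin = torch.clamp(peaks - r_d, min=0)   # window starts r_d symbols below the peak
    s_range = min(r_d + r_u, int(u_len.min()))  # window width
    # keep the whole window inside [0, U - s_range]
    s_begin = torch.minimum(s_begin, (u_len - s_range).clamp(min=0).unsqueeze(1))
    ranges = s_begin.unsqueeze(-1) + torch.arange(s_range)  # (B, T, s_range)
    print(ranges[0])  # each frame scores only 4 of the 6 symbols

Only these (frame, symbol) pairs reach the joint network, so the joint computation and loss memory scale with T * s_range rather than T * U.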

+ 0 - 23
funasr/models/e2e_asr_transducer.py

@@ -353,11 +353,6 @@ class TransducerModel(FunASRModel):
         """
         """
         if self.criterion_transducer is None:
         if self.criterion_transducer is None:
             try:
             try:
-                # from warprnnt_pytorch import RNNTLoss
-	        # self.criterion_transducer = RNNTLoss(
-                    # reduction="mean",
-                    # fastemit_lambda=self.fastemit_lambda,
-                # )
                 from warp_rnnt import rnnt_loss as RNNTLoss
                 from warp_rnnt import rnnt_loss as RNNTLoss
                 self.criterion_transducer = RNNTLoss
                 self.criterion_transducer = RNNTLoss
 
 
@@ -368,12 +363,6 @@ class TransducerModel(FunASRModel):
                 )
                 exit(1)

-        # loss_transducer = self.criterion_transducer(
-        #     joint_out,
-        #     target,
-        #     t_len,
-        #     u_len,
-        # )
         log_probs = torch.log_softmax(joint_out, dim=-1)

         loss_transducer = self.criterion_transducer(
@@ -637,7 +626,6 @@ class UnifiedTransducerModel(FunASRModel):

         batch_size = speech.shape[0]
         text = text[:, : text_lengths.max()]
-        #print(speech.shape)
         # 1. Encoder
         encoder_out, encoder_out_chunk, encoder_out_lens = self.encode(speech, speech_lengths)

@@ -854,11 +842,6 @@ class UnifiedTransducerModel(FunASRModel):
         """
         """
         if self.criterion_transducer is None:
         if self.criterion_transducer is None:
             try:
             try:
-                # from warprnnt_pytorch import RNNTLoss
-            # self.criterion_transducer = RNNTLoss(
-                    # reduction="mean",
-                    # fastemit_lambda=self.fastemit_lambda,
-                # )
                 from warp_rnnt import rnnt_loss as RNNTLoss
                 from warp_rnnt import rnnt_loss as RNNTLoss
                 self.criterion_transducer = RNNTLoss
                 self.criterion_transducer = RNNTLoss
 
 
@@ -869,12 +852,6 @@ class UnifiedTransducerModel(FunASRModel):
                 )
                 exit(1)

-        # loss_transducer = self.criterion_transducer(
-        #     joint_out,
-        #     target,
-        #     t_len,
-        #     u_len,
-        # )
         log_probs = torch.log_softmax(joint_out, dim=-1)

         loss_transducer = self.criterion_transducer(

+ 127 - 0
funasr/models/predictor/cif.py

@@ -1,10 +1,12 @@
 import torch
 from torch import nn
+from torch import Tensor
 import logging
 import numpy as np
 from funasr.torch_utils.device_funcs import to_device
 from funasr.modules.nets_utils import make_pad_mask
 from funasr.modules.streaming_utils.utils import sequence_mask
+from typing import Optional, Tuple

 class CifPredictor(nn.Module):
     def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1, smooth_factor=1.0, noise_threshold=0, tail_threshold=0.45):
@@ -747,3 +749,128 @@ class CifPredictorV3(nn.Module):
         predictor_alignments = index_div_bool_zeros_count_tile_out
         predictor_alignments_length = predictor_alignments.sum(-1).type(encoder_sequence_length.dtype)
         return predictor_alignments.detach(), predictor_alignments_length.detach()
+
+class BATPredictor(nn.Module):
+    def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1, smooth_factor=1.0, noise_threshold=0, return_accum=False):
+        super(BATPredictor, self).__init__()
+
+        self.pad = nn.ConstantPad1d((l_order, r_order), 0)
+        self.cif_conv1d = nn.Conv1d(idim, idim, l_order + r_order + 1, groups=idim)
+        self.cif_output = nn.Linear(idim, 1)
+        self.dropout = torch.nn.Dropout(p=dropout)
+        self.threshold = threshold
+        self.smooth_factor = smooth_factor
+        self.noise_threshold = noise_threshold
+        self.return_accum = return_accum
+
+    def cif(
+        self,
+        input: Tensor,
+        alpha: Tensor,
+        beta: float = 1.0,
+        return_accum: bool = False,
+    ) -> Tuple[Tensor, Tensor]:
+        B, S, C = input.size()
+        assert tuple(alpha.size()) == (B, S), f"{alpha.size()} != {(B, S)}"
+
+        dtype = alpha.dtype
+        alpha = alpha.float()
+
+        alpha_sum = alpha.sum(1)
+        feat_lengths = (alpha_sum / beta).floor().long()
+        T = feat_lengths.max()
+
+        # aggregate and integrate
+        csum = alpha.cumsum(-1)
+        with torch.no_grad():
+            # indices used for scattering
+            right_idx = (csum / beta).floor().long().clip(max=T)
+            left_idx = right_idx.roll(1, dims=1)
+            left_idx[:, 0] = 0
+
+            # count # of fires from each source
+            fire_num = right_idx - left_idx
+            extra_weights = (fire_num - 1).clip(min=0)
+            # the extra entry at time index T absorbs scatters whose indices
+            # were clipped to T; it is trimmed off again below
+            output = input.new_zeros((B, T + 1, C))
+            source_range = torch.arange(1, 1 + S).unsqueeze(0).type_as(input)
+            zero = alpha.new_zeros((1,))
+
+        # right scatter
+        fire_mask = fire_num > 0
+        right_weight = torch.where(
+            fire_mask,
+            csum - right_idx.type_as(alpha) * beta,
+            zero
+        ).type_as(input)
+        # assert right_weight.ge(0).all(), f"{right_weight} should be non-negative."
+        output.scatter_add_(
+            1,
+            right_idx.unsqueeze(-1).expand(-1, -1, C),
+            right_weight.unsqueeze(-1) * input
+        )
+
+        # left scatter
+        left_weight = (
+            alpha - right_weight - extra_weights.type_as(alpha) * beta
+        ).type_as(input)
+        output.scatter_add_(
+            1,
+            left_idx.unsqueeze(-1).expand(-1, -1, C),
+            left_weight.unsqueeze(-1) * input
+        )
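+        # per source frame: left_weight + right_weight + extra_weights * beta == alpha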
+
+        # extra scatters: some frames fire more than once
+        if extra_weights.gt(0).any():
+            extra_steps = extra_weights.max().item()
+            tgt_idx = left_idx
+            src_feats = input * beta
+            for _ in range(extra_steps):
+                tgt_idx = (tgt_idx + 1).clip(max=T)
+                # (B, S, 1)
+                src_mask = (extra_weights > 0)
+                output.scatter_add_(
+                    1,
+                    tgt_idx.unsqueeze(-1).expand(-1, -1, C),
+                    src_feats * src_mask.unsqueeze(2)
+                )
+                extra_weights -= 1
+
+        output = output[:, :T, :]
+
+        if return_accum:
+            return output, csum
+        else:
+            return output, alpha
+
+    def forward(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None, target_label_length=None):
+        h = hidden
+        context = h.transpose(1, 2)
+        queries = self.pad(context)
+        memory = self.cif_conv1d(queries)
+        output = memory + context
+        output = self.dropout(output)
+        output = output.transpose(1, 2)
+        output = torch.relu(output)
+        output = self.cif_output(output)
+        alphas = torch.sigmoid(output)
+        alphas = torch.nn.functional.relu(alphas*self.smooth_factor - self.noise_threshold)
+        if mask is not None:
+            alphas = alphas * mask.transpose(-1, -2).float()
+        if mask_chunk_predictor is not None:
+            alphas = alphas * mask_chunk_predictor
+        alphas = alphas.squeeze(-1)
+        if target_label_length is not None:
+            target_length = target_label_length
+        elif target_label is not None:
+            target_length = (target_label != ignore_id).float().sum(-1)
+            # logging.info("target_length: {}".format(target_length))
+        else:
+            target_length = None
+        token_num = alphas.sum(-1)
+        if target_length is not None:
+            # length_noise = torch.rand(alphas.size(0), device=alphas.device) - 0.5
+            # target_length = length_noise + target_length
+            alphas *= ((target_length + 1e-4) / token_num)[:, None].repeat(1, alphas.size(1))
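+            # rescaling makes the alphas integrate to exactly the target length
+            # during training (the unscaled sum still drives the predictor loss)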
+        acoustic_embeds, cif_peak = self.cif(hidden, alphas, self.threshold, self.return_accum)
+        return acoustic_embeds, token_num, alphas, cif_peak
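
A shape-level sketch of the new predictor's interface (random inputs; the dimensions and filter orders are illustrative, and return_accum=True mirrors how BATModel consumes the fourth output):

    import torch
    from funasr.models.predictor.cif import BATPredictor  # added by this commit

    pred = BATPredictor(idim=256, l_order=1, r_order=1, return_accum=True)
    hidden = torch.randn(2, 50, 256)        # encoder output (B, T, D)
    mask = torch.ones(2, 1, 50)             # all frames valid
    target = torch.randint(1, 100, (2, 7))  # (B, U) label ids, no padding
    embeds, token_num, alphas, peaks = pred(hidden, target, mask)
    # embeds: (2, 7, 256) integrated acoustic embeddings, one per label
    # token_num: (2,) predicted token counts before rescaling (drives loss_pre)
    # alphas: (2, 50) per-frame firing weights
    # peaks: (2, 50) accumulated weights; floor(peaks) is the per-frame boundary
    #        index BATModel uses to center the pruning windows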

+ 137 - 1
funasr/tasks/asr.py

@@ -47,6 +47,7 @@ from funasr.models.e2e_asr_mfcca import MFCCA
 from funasr.models.e2e_sa_asr import SAASRModel
 from funasr.models.e2e_uni_asr import UniASR
 from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel
+from funasr.models.e2e_asr_bat import BATModel
 from funasr.models.encoder.abs_encoder import AbsEncoder
 from funasr.models.encoder.conformer_encoder import ConformerEncoder, ConformerChunkEncoder
 from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
@@ -66,7 +67,7 @@ from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
 from funasr.models.postencoder.hugging_face_transformers_postencoder import (
     HuggingFaceTransformersPostEncoder,  # noqa: H301
 )
-from funasr.models.predictor.cif import CifPredictor, CifPredictorV2, CifPredictorV3
+from funasr.models.predictor.cif import CifPredictor, CifPredictorV2, CifPredictorV3, BATPredictor
 from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
 from funasr.models.preencoder.linear import LinearProjection
 from funasr.models.preencoder.sinc import LightweightSincConvs
@@ -135,6 +136,7 @@ model_choices = ClassChoices(
         timestamp_prediction=TimestampPredictor,
         rnnt=TransducerModel,
         rnnt_unified=UnifiedTransducerModel,
+        bat=BATModel,
         sa_asr=SAASRModel,
     ),
     type_check=FunASRModel,
@@ -266,6 +268,7 @@ predictor_choices = ClassChoices(
         ctc_predictor=None,
         cif_predictor_v2=CifPredictorV2,
         cif_predictor_v3=CifPredictorV3,
+        bat_predictor=BATPredictor,
     ),
     type_check=None,
     default="cif_predictor",
@@ -1508,6 +1511,139 @@ class ASRTransducerTask(ASRTask):

         return model

+class ASRBATTask(ASRTask):
+    """ASR Boundary Aware Transducer Task definition."""
+
+    num_optimizers: int = 1
+
+    class_choices_list = [
+        model_choices,
+        frontend_choices,
+        specaug_choices,
+        normalize_choices,
+        encoder_choices,
+        rnnt_decoder_choices,
+        joint_network_choices,
+        predictor_choices,
+    ]
+
+    trainer = Trainer
+
+    @classmethod
+    def build_model(cls, args: argparse.Namespace) -> BATModel:
+        """Required data depending on task mode.
+        Args:
+            cls: ASRBATTask object.
+            args: Task arguments.
+        Return:
+            model: ASR BAT model.
+        """
+        assert check_argument_types()
+
+        if isinstance(args.token_list, str):
+            with open(args.token_list, encoding="utf-8") as f:
+                token_list = [line.rstrip() for line in f]
+
+            # Overwriting token_list to keep it as "portable".
+            args.token_list = list(token_list)
+        elif isinstance(args.token_list, (tuple, list)):
+            token_list = list(args.token_list)
+        else:
+            raise RuntimeError("token_list must be str or list")
+        vocab_size = len(token_list)
+        logging.info(f"Vocabulary size: {vocab_size }")
+
+        # 1. frontend
+        if args.input_size is None:
+            # Extract features in the model
+            frontend_class = frontend_choices.get_class(args.frontend)
+            frontend = frontend_class(**args.frontend_conf)
+            input_size = frontend.output_size()
+        else:
+            # Give features from data-loader
+            frontend = None
+            input_size = args.input_size
+
+        # 2. Data augmentation for spectrogram
+        if args.specaug is not None:
+            specaug_class = specaug_choices.get_class(args.specaug)
+            specaug = specaug_class(**args.specaug_conf)
+        else:
+            specaug = None
+
+        # 3. Normalization layer
+        if args.normalize is not None:
+            normalize_class = normalize_choices.get_class(args.normalize)
+            normalize = normalize_class(**args.normalize_conf)
+        else:
+            normalize = None
+
+        # 4. Encoder
+        if getattr(args, "encoder", None) is not None:
+            encoder_class = encoder_choices.get_class(args.encoder)
+            encoder = encoder_class(input_size, **args.encoder_conf)
+        else:
+            encoder = Encoder(input_size, **args.encoder_conf)
+        encoder_output_size = encoder.output_size()
+
+        # 5. Decoder
+        rnnt_decoder_class = rnnt_decoder_choices.get_class(args.rnnt_decoder)
+        decoder = rnnt_decoder_class(
+            vocab_size,
+            **args.rnnt_decoder_conf,
+        )
+        decoder_output_size = decoder.output_size
+
+        if getattr(args, "decoder", None) is not None:
+            att_decoder_class = decoder_choices.get_class(args.decoder)
+
+            att_decoder = att_decoder_class(
+                vocab_size=vocab_size,
+                encoder_output_size=encoder_output_size,
+                **args.decoder_conf,
+            )
+        else:
+            att_decoder = None
+        # 6. Joint Network
+        joint_network = JointNetwork(
+            vocab_size,
+            encoder_output_size,
+            decoder_output_size,
+            **args.joint_network_conf,
+        )
+
+        predictor_class = predictor_choices.get_class(args.predictor)
+        predictor = predictor_class(**args.predictor_conf)
+
+        # 7. Build model
+        try:
+            model_class = model_choices.get_class(args.model)
+        except AttributeError:
+            model_class = model_choices.get_class("rnnt_unified")
+
+        model = model_class(
+            vocab_size=vocab_size,
+            token_list=token_list,
+            frontend=frontend,
+            specaug=specaug,
+            normalize=normalize,
+            encoder=encoder,
+            decoder=decoder,
+            att_decoder=att_decoder,
+            joint_network=joint_network,
+            predictor=predictor,
+            **args.model_conf,
+        )
+        # 8. Initialize model
+        if args.init is not None:
+            raise NotImplementedError(
+                "Currently not supported.",
+                "Initialization part will be reworked in a short future.",
+            )
+
+        #assert check_return_type(model)
+
+        return model
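
A sketch of how the new task class would typically be launched (this mirrors the ESPnet-style AbsTask.main entry point other FunASR tasks use; the entry point and config path are assumptions, not part of this commit):

    from funasr.tasks.asr import ASRBATTask

    # main() builds the CLI parser from class_choices_list above (frontend,
    # specaug, normalize, encoder, rnnt_decoder, joint_network, predictor)
    # and eventually calls build_model().
    if __name__ == "__main__":
        ASRBATTask.main(cmd=["--config", "conf/train_bat.yaml"])  # hypothetical config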

class ASRTaskSAASR(ASRTask):
    # If you need more than one optimizers, change this value