Просмотр исходного кода

Merge pull request #571 from alibaba-damo-academy/dev_lhn

Dev lhn
hnluo 2 года назад
Родитель
Commit
4ce2e2d76c

+ 1 - 0
egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/README.md

@@ -0,0 +1 @@
+../../TEMPLATE/README.md

+ 5 - 32
egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/demo.py

@@ -1,39 +1,12 @@
-import os
-import logging
-import torch
-import soundfile
-
 from modelscope.pipelines import pipeline
 from modelscope.utils.constant import Tasks
-from modelscope.utils.logger import get_logger
-
-logger = get_logger(log_level=logging.CRITICAL)
-logger.setLevel(logging.CRITICAL)
 
-os.environ["MODELSCOPE_CACHE"] = "./"
 inference_pipeline = pipeline(
     task=Tasks.auto_speech_recognition,
     model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online',
-    model_revision='v1.0.4'
+    model_revision='v1.0.6',
+    mode="paraformer_fake_streaming"
 )
-
-model_dir = os.path.join(os.environ["MODELSCOPE_CACHE"], "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online")
-speech, sample_rate = soundfile.read(os.path.join(model_dir, "example/asr_example.wav"))
-speech_length = speech.shape[0]
-
-sample_offset = 0
-chunk_size = [5, 10, 5] #[5, 10, 5] 600ms, [8, 8, 4] 480ms
-stride_size =  chunk_size[1] * 960
-param_dict = {"cache": dict(), "is_final": False, "chunk_size": chunk_size}
-final_result = ""
-
-for sample_offset in range(0, speech_length, min(stride_size, speech_length - sample_offset)):
-    if sample_offset + stride_size >= speech_length - 1:
-        stride_size = speech_length - sample_offset
-        param_dict["is_final"] = True
-    rec_result = inference_pipeline(audio_in=speech[sample_offset: sample_offset + stride_size],
-                                    param_dict=param_dict)
-    if len(rec_result) != 0:
-        final_result += rec_result['text']
-        print(rec_result)
-print(final_result)
+audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav'
+rec_result = inference_pipeline(audio_in=audio_in)
+print(rec_result)

+ 40 - 0
egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/demo_online.py

@@ -0,0 +1,40 @@
+import os
+import logging
+import torch
+import soundfile
+
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.logger import get_logger
+
+logger = get_logger(log_level=logging.CRITICAL)
+logger.setLevel(logging.CRITICAL)
+
+os.environ["MODELSCOPE_CACHE"] = "./"
+inference_pipeline = pipeline(
+    task=Tasks.auto_speech_recognition,
+    model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online',
+    model_revision='v1.0.6',
+    mode="paraformer_streaming"
+)
+
+model_dir = os.path.join(os.environ["MODELSCOPE_CACHE"], "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online")
+speech, sample_rate = soundfile.read(os.path.join(model_dir, "example/asr_example.wav"))
+speech_length = speech.shape[0]
+
+sample_offset = 0
+chunk_size = [5, 10, 5] #[5, 10, 5] 600ms, [8, 8, 4] 480ms
+stride_size =  chunk_size[1] * 960
+param_dict = {"cache": dict(), "is_final": False, "chunk_size": chunk_size}
+final_result = ""
+
+for sample_offset in range(0, speech_length, min(stride_size, speech_length - sample_offset)):
+    if sample_offset + stride_size >= speech_length - 1:
+        stride_size = speech_length - sample_offset
+        param_dict["is_final"] = True
+    rec_result = inference_pipeline(audio_in=speech[sample_offset: sample_offset + stride_size],
+                                    param_dict=param_dict)
+    if len(rec_result) != 0:
+        final_result += rec_result['text']
+        print(rec_result)
+print(final_result)

+ 37 - 0
egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/finetune.py

@@ -0,0 +1,37 @@
+import os
+
+from modelscope.metainfo import Trainers
+from modelscope.trainers import build_trainer
+
+from funasr.datasets.ms_dataset import MsDataset
+from funasr.utils.modelscope_param import modelscope_args
+
+
+def modelscope_finetune(params):
+    if not os.path.exists(params.output_dir):
+        os.makedirs(params.output_dir, exist_ok=True)
+    # dataset split ["train", "validation"]
+    ds_dict = MsDataset.load(params.data_path)
+    kwargs = dict(
+        model=params.model,
+        model_revision='v1.0.6',
+        data_dir=ds_dict,
+        dataset_type=params.dataset_type,
+        work_dir=params.output_dir,
+        batch_bins=params.batch_bins,
+        max_epoch=params.max_epoch,
+        lr=params.lr)
+    trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
+    trainer.train()
+
+
+if __name__ == '__main__':
+    params = modelscope_args(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online", data_path="./data")
+    params.output_dir = "./checkpoint"              # 模型保存路径
+    params.data_path = "./example_data/"            # 数据路径
+    params.dataset_type = "small"                   # 小数据量设置small,若数据量大于1000小时,请使用large
+    params.batch_bins = 1000                       # batch size,如果dataset_type="small",batch_bins单位为fbank特征帧数,如果dataset_type="large",batch_bins单位为毫秒,
+    params.max_epoch = 20                           # 最大训练轮数
+    params.lr = 0.00005                             # 设置学习率
+    
+    modelscope_finetune(params)

+ 32 - 0
egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/infer.py

@@ -0,0 +1,32 @@
+import os
+import shutil
+import argparse
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+
+def modelscope_infer(args):
+    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpuid)
+    inference_pipeline = pipeline(
+        task=Tasks.auto_speech_recognition,
+        model=args.model,
+        output_dir=args.output_dir,
+        batch_size=args.batch_size,
+        model_revision=args.model_revision if args.model_revision is not None else 'v1.0.6',
+        mode=args.mode if args.mode is not None else "paraformer_fake_streaming",
+        param_dict={"decoding_model": args.decoding_mode, "hotword": args.hotword_txt}
+    )
+    inference_pipeline(audio_in=args.audio_in)
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--model', type=str, default="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online")
+    parser.add_argument('--audio_in', type=str, default="./data/test/wav.scp")
+    parser.add_argument('--output_dir', type=str, default="./results/")
+    parser.add_argument('--decoding_mode', type=str, default="normal")
+    parser.add_argument('--model_revision', type=str, default=None)
+    parser.add_argument('--mode', type=str, default=None)
+    parser.add_argument('--hotword_txt', type=str, default=None)
+    parser.add_argument('--batch_size', type=int, default=64)
+    parser.add_argument('--gpuid', type=str, default="0")
+    args = parser.parse_args()
+    modelscope_infer(args)

+ 104 - 0
egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/infer.sh

@@ -0,0 +1,104 @@
+#!/usr/bin/env bash
+
+set -e
+set -u
+set -o pipefail
+
+stage=1
+stop_stage=2
+model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online"
+data_dir="./data/test"
+output_dir="./results"
+batch_size=32
+gpu_inference=true    # whether to perform gpu decoding
+gpuid_list="0,1"    # set gpus, e.g., gpuid_list="0,1"
+njob=32    # the number of jobs for CPU decoding, if gpu_inference=false, use CPU decoding, please set njob
+checkpoint_dir=
+checkpoint_name="valid.cer_ctc.ave.pb"
+
+. utils/parse_options.sh || exit 1;
+
+if [ ${gpu_inference} == "true" ]; then
+    nj=$(echo $gpuid_list | awk -F "," '{print NF}')
+else
+    nj=$njob
+    batch_size=1
+    gpuid_list=""
+    for JOB in $(seq ${nj}); do
+        gpuid_list=$gpuid_list"-1,"
+    done
+fi
+
+mkdir -p $output_dir/split
+split_scps=""
+for JOB in $(seq ${nj}); do
+    split_scps="$split_scps $output_dir/split/wav.$JOB.scp"
+done
+perl utils/split_scp.pl ${data_dir}/wav.scp ${split_scps}
+
+if [ -n "${checkpoint_dir}" ]; then
+  python utils/prepare_checkpoint.py ${model} ${checkpoint_dir} ${checkpoint_name}
+  model=${checkpoint_dir}/${model}
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ];then
+    echo "Decoding ..."
+    gpuid_list_array=(${gpuid_list//,/ })
+    for JOB in $(seq ${nj}); do
+        {
+        id=$((JOB-1))
+        gpuid=${gpuid_list_array[$id]}
+        mkdir -p ${output_dir}/output.$JOB
+        python infer.py \
+            --model ${model} \
+            --audio_in ${output_dir}/split/wav.$JOB.scp \
+            --output_dir ${output_dir}/output.$JOB \
+            --batch_size ${batch_size} \
+            --gpuid ${gpuid} \
+            --mode "paraformer_fake_streaming"
+        }&
+    done
+    wait
+
+    mkdir -p ${output_dir}/1best_recog
+    for f in token score text; do
+        if [ -f "${output_dir}/output.1/1best_recog/${f}" ]; then
+          for i in $(seq "${nj}"); do
+              cat "${output_dir}/output.${i}/1best_recog/${f}"
+          done | sort -k1 >"${output_dir}/1best_recog/${f}"
+        fi
+    done
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ];then
+    echo "Computing WER ..."
+    cp ${output_dir}/1best_recog/text ${output_dir}/1best_recog/text.proc
+    cp ${data_dir}/text ${output_dir}/1best_recog/text.ref
+    python utils/compute_wer.py ${output_dir}/1best_recog/text.ref ${output_dir}/1best_recog/text.proc ${output_dir}/1best_recog/text.cer
+    tail -n 3 ${output_dir}/1best_recog/text.cer
+fi
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ];then
+    echo "SpeechIO TIOBE textnorm"
+    echo "$0 --> Normalizing REF text ..."
+    ./utils/textnorm_zh.py \
+        --has_key --to_upper \
+        ${data_dir}/text \
+        ${output_dir}/1best_recog/ref.txt
+
+    echo "$0 --> Normalizing HYP text ..."
+    ./utils/textnorm_zh.py \
+        --has_key --to_upper \
+        ${output_dir}/1best_recog/text.proc \
+        ${output_dir}/1best_recog/rec.txt
+    grep -v $'\t$' ${output_dir}/1best_recog/rec.txt > ${output_dir}/1best_recog/rec_non_empty.txt
+
+    echo "$0 --> computing WER/CER and alignment ..."
+    ./utils/error_rate_zh \
+        --tokenizer char \
+        --ref ${output_dir}/1best_recog/ref.txt \
+        --hyp ${output_dir}/1best_recog/rec_non_empty.txt \
+        ${output_dir}/1best_recog/DETAILS.txt | tee ${output_dir}/1best_recog/RESULTS.txt
+    rm -rf ${output_dir}/1best_recog/rec.txt ${output_dir}/1best_recog/rec_non_empty.txt
+fi
+

+ 1 - 0
egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/utils

@@ -0,0 +1 @@
+../../TEMPLATE/utils/

+ 3 - 1
funasr/bin/asr_infer.py

@@ -305,6 +305,7 @@ class Speech2TextParaformer:
             nbest: int = 1,
             frontend_conf: dict = None,
             hotword_list_or_file: str = None,
+            decoding_ind: int = 0,
             **kwargs,
     ):
         assert check_argument_types()
@@ -415,6 +416,7 @@ class Speech2TextParaformer:
         self.nbest = nbest
         self.frontend = frontend
         self.encoder_downsampling_factor = 1
+        self.decoding_ind = decoding_ind
         if asr_train_args.encoder == "data2vec_encoder" or asr_train_args.encoder_conf["input_layer"] == "conv2d":
             self.encoder_downsampling_factor = 4
 
@@ -452,7 +454,7 @@ class Speech2TextParaformer:
         batch = to_device(batch, device=self.device)
 
         # b. Forward Encoder
-        enc, enc_len = self.asr_model.encode(**batch)
+        enc, enc_len = self.asr_model.encode(**batch, ind=self.decoding_ind)
         if isinstance(enc, tuple):
             enc = enc[0]
         # assert len(enc) == 1, len(enc)

+ 3 - 1
funasr/bin/asr_inference_launch.py

@@ -1638,6 +1638,8 @@ def inference_launch(**kwargs):
         return inference_uniasr(**kwargs)
     elif mode == "paraformer":
         return inference_paraformer(**kwargs)
+    elif mode == "paraformer_fake_streaming":
+        return inference_paraformer(**kwargs)
     elif mode == "paraformer_streaming":
         return inference_paraformer_online(**kwargs)
     elif mode.startswith("paraformer_vad"):
@@ -1920,4 +1922,4 @@ def main(cmd=None):
 
 
 if __name__ == "__main__":
-    main()
+    main()

+ 2 - 0
funasr/bin/build_trainer.py

@@ -23,6 +23,8 @@ def parse_args(mode):
         from funasr.tasks.asr import ASRTask as ASRTask
     elif mode == "paraformer":
         from funasr.tasks.asr import ASRTaskParaformer as ASRTask
+    elif mode == "paraformer_streaming":
+        from funasr.tasks.asr import ASRTaskParaformer as ASRTask
     elif mode == "paraformer_vad_punc":
         from funasr.tasks.asr import ASRTaskParaformer as ASRTask
     elif mode == "uniasr":

+ 3 - 2
funasr/build_utils/build_asr_model.py

@@ -23,7 +23,7 @@ from funasr.models.decoder.rnnt_decoder import RNNTDecoder
 from funasr.models.joint_net.joint_network import JointNetwork
 from funasr.models.e2e_asr import ASRModel
 from funasr.models.e2e_asr_mfcca import MFCCA
-from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer, ContextualParaformer
+from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, ParaformerBert, BiCifParaformer, ContextualParaformer
 from funasr.models.e2e_tp import TimestampPredictor
 from funasr.models.e2e_uni_asr import UniASR
 from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel
@@ -82,6 +82,7 @@ model_choices = ClassChoices(
         asr=ASRModel,
         uniasr=UniASR,
         paraformer=Paraformer,
+        paraformer_online=ParaformerOnline,
         paraformer_bert=ParaformerBert,
         bicif_paraformer=BiCifParaformer,
         contextual_paraformer=ContextualParaformer,
@@ -293,7 +294,7 @@ def build_asr_model(args):
             token_list=token_list,
             **args.model_conf,
         )
-    elif args.model in ["paraformer", "paraformer_bert", "bicif_paraformer", "contextual_paraformer"]:
+    elif args.model in ["paraformer", "paraformer_online", "paraformer_bert", "bicif_paraformer", "contextual_paraformer"]:
         # predictor
         predictor_class = predictor_choices.get_class(args.predictor)
         predictor = predictor_class(**args.predictor_conf)

+ 6 - 1
funasr/models/decoder/sanm_decoder.py

@@ -935,6 +935,7 @@ class ParaformerSANMDecoder(BaseTransformerDecoder):
         hlens: torch.Tensor,
         ys_in_pad: torch.Tensor,
         ys_in_lens: torch.Tensor,
+        chunk_mask: torch.Tensor = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Forward decoder.
 
@@ -955,9 +956,13 @@ class ParaformerSANMDecoder(BaseTransformerDecoder):
         """
         tgt = ys_in_pad
         tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
-
+        
         memory = hs_pad
         memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
+        if chunk_mask is not None:
+            memory_mask = memory_mask * chunk_mask
+            if tgt_mask.size(1) != memory_mask.size(1):
+                memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1)
 
         x = tgt
         x, tgt_mask, memory, memory_mask, _ = self.decoders(

+ 468 - 17
funasr/models/e2e_asr_paraformer.py

@@ -153,6 +153,7 @@ class Paraformer(FunASRModel):
             speech_lengths: torch.Tensor,
             text: torch.Tensor,
             text_lengths: torch.Tensor,
+            decoding_ind: int = None,
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
         """Frontend + Encoder + Decoder + Calc loss
         Args:
@@ -160,6 +161,7 @@ class Paraformer(FunASRModel):
                 speech_lengths: (Batch, )
                 text: (Batch, Length)
                 text_lengths: (Batch,)
+                decoding_ind: int
         """
         assert text_lengths.dim() == 1, text_lengths.shape
         # Check that batch_size is unified
@@ -176,7 +178,11 @@ class Paraformer(FunASRModel):
         speech = speech[:, :speech_lengths.max()]
 
         # 1. Encoder
-        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+        if hasattr(self.encoder, "overlap_chunk_cls"):
+            ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
+            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
+        else:
+            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
         intermediate_outs = None
         if isinstance(encoder_out, tuple):
             intermediate_outs = encoder_out[1]
@@ -272,12 +278,13 @@ class Paraformer(FunASRModel):
         return {"feats": feats, "feats_lengths": feats_lengths}
 
     def encode(
-            self, speech: torch.Tensor, speech_lengths: torch.Tensor
+            self, speech: torch.Tensor, speech_lengths: torch.Tensor, ind: int = 0,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Frontend + Encoder. Note that this method is used by asr_inference.py
         Args:
                 speech: (Batch, Length, ...)
                 speech_lengths: (Batch, )
+                ind: int
         """
         with autocast(False):
             # 1. Extract feats
@@ -299,11 +306,25 @@ class Paraformer(FunASRModel):
         # feats: (Batch, Length, Dim)
         # -> encoder_out: (Batch, Length2, Dim2)
         if self.encoder.interctc_use_conditioning:
-            encoder_out, encoder_out_lens, _ = self.encoder(
-                feats, feats_lengths, ctc=self.ctc
-            )
+            if hasattr(self.encoder, "overlap_chunk_cls"):
+                encoder_out, encoder_out_lens, _ = self.encoder(
+                    feats, feats_lengths, ctc=self.ctc, ind=ind
+                )
+                encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out,
+                                                                                            encoder_out_lens,
+                                                                                            chunk_outs=None)
+            else:
+                encoder_out, encoder_out_lens, _ = self.encoder(
+                    feats, feats_lengths, ctc=self.ctc
+                )
         else:
-            encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
+            if hasattr(self.encoder, "overlap_chunk_cls"):
+                encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths, ind=ind)
+                encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out,
+                                                                                            encoder_out_lens,
+                                                                                            chunk_outs=None)
+            else:
+                encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
         intermediate_outs = None
         if isinstance(encoder_out, tuple):
             intermediate_outs = encoder_out[1]
@@ -592,9 +613,137 @@ class ParaformerOnline(Paraformer):
     """
 
     def __init__(
-            self, *args, **kwargs,
+            self,
+            vocab_size: int,
+            token_list: Union[Tuple[str, ...], List[str]],
+            frontend: Optional[AbsFrontend],
+            specaug: Optional[AbsSpecAug],
+            normalize: Optional[AbsNormalize],
+            encoder: AbsEncoder,
+            decoder: AbsDecoder,
+            ctc: CTC,
+            ctc_weight: float = 0.5,
+            interctc_weight: float = 0.0,
+            ignore_id: int = -1,
+            blank_id: int = 0,
+            sos: int = 1,
+            eos: int = 2,
+            lsm_weight: float = 0.0,
+            length_normalized_loss: bool = False,
+            report_cer: bool = True,
+            report_wer: bool = True,
+            sym_space: str = "<space>",
+            sym_blank: str = "<blank>",
+            extract_feats_in_collect_stats: bool = True,
+            predictor=None,
+            predictor_weight: float = 0.0,
+            predictor_bias: int = 0,
+            sampling_ratio: float = 0.2,
+            decoder_attention_chunk_type: str = 'chunk',
+            share_embedding: bool = False,
+            preencoder: Optional[AbsPreEncoder] = None,
+            postencoder: Optional[AbsPostEncoder] = None,
+            use_1st_decoder_loss: bool = False,
     ):
-        super().__init__(*args, **kwargs)
+        assert check_argument_types()
+        assert 0.0 <= ctc_weight <= 1.0, ctc_weight
+        assert 0.0 <= interctc_weight < 1.0, interctc_weight
+
+        super().__init__(
+            vocab_size=vocab_size,
+            token_list=token_list,
+            frontend=frontend,
+            specaug=specaug,
+            normalize=normalize,
+            preencoder=preencoder,
+            encoder=encoder,
+            postencoder=postencoder,
+            decoder=decoder,
+            ctc=ctc,
+            ctc_weight=ctc_weight,
+            interctc_weight=interctc_weight,
+            ignore_id=ignore_id,
+            blank_id=blank_id,
+            sos=sos,
+            eos=eos,
+            lsm_weight=lsm_weight,
+            length_normalized_loss=length_normalized_loss,
+            report_cer=report_cer,
+            report_wer=report_wer,
+            sym_space=sym_space,
+            sym_blank=sym_blank,
+            extract_feats_in_collect_stats=extract_feats_in_collect_stats,
+            predictor=predictor,
+            predictor_weight=predictor_weight,
+            predictor_bias=predictor_bias,
+            sampling_ratio=sampling_ratio,
+        )
+        # note that eos is the same as sos (equivalent ID)
+        self.blank_id = blank_id
+        self.sos = vocab_size - 1 if sos is None else sos
+        self.eos = vocab_size - 1 if eos is None else eos
+        self.vocab_size = vocab_size
+        self.ignore_id = ignore_id
+        self.ctc_weight = ctc_weight
+        self.interctc_weight = interctc_weight
+        self.token_list = token_list.copy()
+
+        self.frontend = frontend
+        self.specaug = specaug
+        self.normalize = normalize
+        self.preencoder = preencoder
+        self.postencoder = postencoder
+        self.encoder = encoder
+
+        if not hasattr(self.encoder, "interctc_use_conditioning"):
+            self.encoder.interctc_use_conditioning = False
+        if self.encoder.interctc_use_conditioning:
+            self.encoder.conditioning_layer = torch.nn.Linear(
+                vocab_size, self.encoder.output_size()
+            )
+
+        self.error_calculator = None
+
+        if ctc_weight == 1.0:
+            self.decoder = None
+        else:
+            self.decoder = decoder
+
+        self.criterion_att = LabelSmoothingLoss(
+            size=vocab_size,
+            padding_idx=ignore_id,
+            smoothing=lsm_weight,
+            normalize_length=length_normalized_loss,
+        )
+
+        if report_cer or report_wer:
+            self.error_calculator = ErrorCalculator(
+                token_list, sym_space, sym_blank, report_cer, report_wer
+            )
+
+        if ctc_weight == 0.0:
+            self.ctc = None
+        else:
+            self.ctc = ctc
+
+        self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
+        self.predictor = predictor
+        self.predictor_weight = predictor_weight
+        self.predictor_bias = predictor_bias
+        self.sampling_ratio = sampling_ratio
+        self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)
+        self.step_cur = 0
+        self.scama_mask = None
+        if hasattr(self.encoder, "overlap_chunk_cls") and self.encoder.overlap_chunk_cls is not None:
+            from funasr.modules.streaming_utils.chunk_utilis import build_scama_mask_for_cross_attention_decoder
+            self.build_scama_mask_for_cross_attention_decoder_fn = build_scama_mask_for_cross_attention_decoder
+            self.decoder_attention_chunk_type = decoder_attention_chunk_type
+
+        self.share_embedding = share_embedding
+        if self.share_embedding:
+            self.decoder.embed = None
+
+        self.use_1st_decoder_loss = use_1st_decoder_loss
 
     def forward(
             self,
@@ -602,6 +751,7 @@ class ParaformerOnline(Paraformer):
             speech_lengths: torch.Tensor,
             text: torch.Tensor,
             text_lengths: torch.Tensor,
+            decoding_ind: int = None,
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
         """Frontend + Encoder + Decoder + Calc loss
         Args:
@@ -609,6 +759,7 @@ class ParaformerOnline(Paraformer):
                 speech_lengths: (Batch, )
                 text: (Batch, Length)
                 text_lengths: (Batch,)
+                decoding_ind: int
         """
         assert text_lengths.dim() == 1, text_lengths.shape
         # Check that batch_size is unified
@@ -625,7 +776,11 @@ class ParaformerOnline(Paraformer):
         speech = speech[:, :speech_lengths.max()]
 
         # 1. Encoder
-        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+        if hasattr(self.encoder, "overlap_chunk_cls"):
+            ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
+            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
+        else:
+            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
         intermediate_outs = None
         if isinstance(encoder_out, tuple):
             intermediate_outs = encoder_out[1]
@@ -638,8 +793,12 @@ class ParaformerOnline(Paraformer):
 
         # 1. CTC branch
         if self.ctc_weight != 0.0:
+            if hasattr(self.encoder, "overlap_chunk_cls"):
+                encoder_out_ctc, encoder_out_lens_ctc = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out,
+                                                                                                    encoder_out_lens,
+                                                                                                    chunk_outs=None)
+            else:
+                encoder_out_ctc, encoder_out_lens_ctc = encoder_out, encoder_out_lens
             loss_ctc, cer_ctc = self._calc_ctc_loss(
-                encoder_out, encoder_out_lens, text, text_lengths
+                encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
             )
 
             # Collect CTC branch stats
@@ -652,8 +811,14 @@ class ParaformerOnline(Paraformer):
             for layer_idx, intermediate_out in intermediate_outs:
                 # we assume intermediate_out has the same length & padding
                 # as those of encoder_out
+                if hasattr(self.encoder, "overlap_chunk_cls"):
+                    encoder_out_ctc, encoder_out_lens_ctc = \
+                        self.encoder.overlap_chunk_cls.remove_chunk(
+                            intermediate_out,
+                            encoder_out_lens,
+                            chunk_outs=None)
                 loss_ic, cer_ic = self._calc_ctc_loss(
-                    intermediate_out, encoder_out_lens, text, text_lengths
+                    encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
                 )
                 loss_interctc = loss_interctc + loss_ic
 
@@ -672,7 +837,7 @@ class ParaformerOnline(Paraformer):
 
         # 2b. Attention decoder branch
         if self.ctc_weight != 1.0:
-            loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_loss(
+            loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att = self._calc_att_predictor_loss(
                 encoder_out, encoder_out_lens, text, text_lengths
             )
 
@@ -684,8 +849,12 @@ class ParaformerOnline(Paraformer):
         else:
             loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
 
+        if self.use_1st_decoder_loss and pre_loss_att is not None:
+            loss = loss + pre_loss_att
+
         # Collect Attn branch stats
         stats["loss_att"] = loss_att.detach() if loss_att is not None else None
+        stats["pre_loss_att"] = pre_loss_att.detach() if pre_loss_att is not None else None
         stats["acc"] = acc_att
         stats["cer"] = cer_att
         stats["wer"] = wer_att
@@ -697,14 +866,67 @@ class ParaformerOnline(Paraformer):
         loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
         return loss, stats, weight
 
+    def encode(
+        self, speech: torch.Tensor, speech_lengths: torch.Tensor, ind: int = 0,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Frontend + Encoder. Note that this method is used by asr_inference.py
+        Args:
+                        speech: (Batch, Length, ...)
+                        speech_lengths: (Batch, )
+        """
+        with autocast(False):
+            # 1. Extract feats
+            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
+
+            # 2. Data augmentation
+            if self.specaug is not None and self.training:
+                feats, feats_lengths = self.specaug(feats, feats_lengths)
+
+            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+            if self.normalize is not None:
+                feats, feats_lengths = self.normalize(feats, feats_lengths)
+        # Pre-encoder, e.g. used for raw input data
+        if self.preencoder is not None:
+            feats, feats_lengths = self.preencoder(feats, feats_lengths)
+        
+        # 4. Forward encoder
+        # feats: (Batch, Length, Dim)
+        # -> encoder_out: (Batch, Length2, Dim2)
+        if self.encoder.interctc_use_conditioning:
+            encoder_out, encoder_out_lens, _ = self.encoder(
+                feats, feats_lengths, ctc=self.ctc, ind=ind
+            )
+        else:
+            encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths, ind=ind)
+        intermediate_outs = None
+        if isinstance(encoder_out, tuple):
+            intermediate_outs = encoder_out[1]
+            encoder_out = encoder_out[0]
+
+        # Post-encoder, e.g. NLU
+        if self.postencoder is not None:
+            encoder_out, encoder_out_lens = self.postencoder(
+                encoder_out, encoder_out_lens
+            )
+
+        assert encoder_out.size(0) == speech.size(0), (
+            encoder_out.size(),
+            speech.size(0),
+        )
+        assert encoder_out.size(1) <= encoder_out_lens.max(), (
+            encoder_out.size(),
+            encoder_out_lens.max(),
+        )
+
+        if intermediate_outs is not None:
+            return (encoder_out, intermediate_outs), encoder_out_lens
+
+        return encoder_out, encoder_out_lens
+
     def encode_chunk(
             self, speech: torch.Tensor, speech_lengths: torch.Tensor, cache: dict = None
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Frontend + Encoder. Note that this method is used by asr_inference.py
-<<<<<<< HEAD
-=======
-
->>>>>>> 4cd79db451786548d8a100f25c3b03da0eb30f4b
         Args:
                 speech: (Batch, Length, ...)
                 speech_lengths: (Batch, )
@@ -750,12 +972,241 @@ class ParaformerOnline(Paraformer):
 
         return encoder_out, torch.tensor([encoder_out.size(1)])
 
+    def _calc_att_predictor_loss(
+        self,
+        encoder_out: torch.Tensor,
+        encoder_out_lens: torch.Tensor,
+        ys_pad: torch.Tensor,
+        ys_pad_lens: torch.Tensor,
+    ):
+        """Compute the attention-decoder loss plus the predictor (token-count) loss.
+
+        Args:
+            encoder_out: encoder output sequences (Batch, T, D).
+            encoder_out_lens: valid lengths of ``encoder_out`` (Batch,).
+            ys_pad: padded target token ids (Batch, Lmax).
+            ys_pad_lens: valid target lengths (Batch,).
+
+        Returns:
+            Tuple of (loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att):
+            cross-entropy loss of the final decoder pass, token accuracy of the
+            first pass, CER/WER (None while training or without an error
+            calculator), the predictor length loss, and the first-pass attention
+            loss (None unless ``use_1st_decoder_loss``).
+        """
+        # Non-padding mask over encoder frames, shaped (Batch, 1, T) for the predictor.
+        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+            encoder_out.device)
+        # The predictor must count one extra token: append <eos> to the targets
+        # and grow the target lengths by the same bias.
+        if self.predictor_bias == 1:
+            _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
+            ys_pad_lens = ys_pad_lens + self.predictor_bias
+        # With chunk-overlap training, mask the overlapped/shifted regions via the
+        # encoder's overlap helper before feeding the predictor.
+        mask_chunk_predictor = None
+        if self.encoder.overlap_chunk_cls is not None:
+            mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(None,
+                                                                                           device=encoder_out.device,
+                                                                                           batch_size=encoder_out.size(
+                                                                                               0))
+            mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(None, device=encoder_out.device,
+                                                                                   batch_size=encoder_out.size(0))
+            encoder_out = encoder_out * mask_shfit_chunk
+        # Predictor yields one acoustic embedding per target token, the predicted
+        # token count, and per-frame weights (alphas).
+        pre_acoustic_embeds, pre_token_length, pre_alphas, _ = self.predictor(encoder_out,
+                                                                              ys_pad,
+                                                                              encoder_out_mask,
+                                                                              ignore_id=self.ignore_id,
+                                                                              mask_chunk_predictor=mask_chunk_predictor,
+                                                                              target_label_length=ys_pad_lens,
+                                                                              )
+        # Frame-level alignments derived from the alphas; used below to build the
+        # streaming (SCAMA) cross-attention mask.
+        predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(pre_alphas,
+                                                                                             encoder_out_lens)
+
+        scama_mask = None
+        if self.encoder.overlap_chunk_cls is not None and self.decoder_attention_chunk_type == 'chunk':
+            # Chunk-wise decoder attention: restrict each target step's
+            # cross-attention to its aligned encoder chunk (plus look-back).
+            encoder_chunk_size = self.encoder.overlap_chunk_cls.chunk_size_pad_shift_cur
+            attention_chunk_center_bias = 0
+            attention_chunk_size = encoder_chunk_size
+            decoder_att_look_back_factor = self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur
+            mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls.\
+                get_mask_shift_att_chunk_decoder(None,
+                                                 device=encoder_out.device,
+                                                 batch_size=encoder_out.size(0)
+                                                 )
+            scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn(
+                predictor_alignments=predictor_alignments,
+                encoder_sequence_length=encoder_out_lens,
+                chunk_size=1,
+                encoder_chunk_size=encoder_chunk_size,
+                attention_chunk_center_bias=attention_chunk_center_bias,
+                attention_chunk_size=attention_chunk_size,
+                attention_chunk_type=self.decoder_attention_chunk_type,
+                step=None,
+                predictor_mask_chunk_hopping=mask_chunk_predictor,
+                decoder_att_look_back_factor=decoder_att_look_back_factor,
+                mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder,
+                target_length=ys_pad_lens,
+                is_training=self.training,
+            )
+        elif self.encoder.overlap_chunk_cls is not None:
+            # No chunked decoder attention: fold the overlapped chunks back into a
+            # plain sequence before decoding.
+            encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out,
+                                                                                        encoder_out_lens,
+                                                                                        chunk_outs=None)
+        # 0. sampler: optionally mix ground-truth token embeddings into the
+        # predictor's acoustic embeddings (ratio controls how many positions).
+        decoder_out_1st = None
+        pre_loss_att = None
+        if self.sampling_ratio > 0.0:
+            # Log only during the first couple of steps to avoid spam.
+            if self.step_cur < 2:
+                logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
+            if self.use_1st_decoder_loss:
+                # First-pass decoder runs WITH grad so its loss can be returned.
+                sematic_embeds, decoder_out_1st, pre_loss_att = \
+                    self.sampler_with_grad(encoder_out, encoder_out_lens, ys_pad,
+                                           ys_pad_lens, pre_acoustic_embeds, scama_mask)
+            else:
+                sematic_embeds, decoder_out_1st = \
+                    self.sampler(encoder_out, encoder_out_lens, ys_pad,
+                                 ys_pad_lens, pre_acoustic_embeds, scama_mask)
+        else:
+            if self.step_cur < 2:
+                logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
+            sematic_embeds = pre_acoustic_embeds
+
+        # 1. Forward decoder (second pass when the sampler ran)
+        decoder_outs = self.decoder(
+            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, scama_mask
+        )
+        decoder_out, _ = decoder_outs[0], decoder_outs[1]
+
+        # Accuracy / error rates are reported on the first pass when it exists.
+        if decoder_out_1st is None:
+            decoder_out_1st = decoder_out
+        # 2. Compute attention loss
+        loss_att = self.criterion_att(decoder_out, ys_pad)
+        acc_att = th_accuracy(
+            decoder_out_1st.view(-1, self.vocab_size),
+            ys_pad,
+            ignore_label=self.ignore_id,
+        )
+        # Predictor loss: predicted token count vs. true target length.
+        loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
+
+        # Compute cer/wer using attention-decoder
+        if self.training or self.error_calculator is None:
+            cer_att, wer_att = None, None
+        else:
+            ys_hat = decoder_out_1st.argmax(dim=-1)
+            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
+
+        return loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att
+
+    def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, chunk_mask=None):
+        """Mix ground-truth token embeddings into the predictor's acoustic embeddings.
+
+        Runs a gradient-free first decoder pass, counts the positions it got
+        wrong, and replaces a ``sampling_ratio`` fraction of (wrong-count scaled)
+        random target positions with the ground-truth embedding.
+
+        Returns:
+            (sematic_embeds, decoder_out): the mixed decoder inputs and the
+            first-pass logits, both zeroed at padded target positions.
+        """
+        # Non-padding mask over target steps: (Batch, Lmax, 1).
+        tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
+        # Zero out ignore_id paddings before the embedding lookup.
+        ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
+        if self.share_embedding:
+            # Tied weights: reuse the output projection matrix as the embedding table.
+            ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
+        else:
+            ys_pad_embed = self.decoder.embed(ys_pad_masked)
+        # First pass is detached — only the mixing decision uses it.
+        with torch.no_grad():
+            decoder_outs = self.decoder(
+                encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens, chunk_mask
+            )
+            decoder_out, _ = decoder_outs[0], decoder_outs[1]
+            pred_tokens = decoder_out.argmax(-1)
+            nonpad_positions = ys_pad.ne(self.ignore_id)
+            seq_lens = (nonpad_positions).sum(1)
+            # Per-utterance count of correctly predicted tokens.
+            same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
+            # input_mask == 1 -> keep acoustic embedding; 0 -> substitute ground truth.
+            input_mask = torch.ones_like(nonpad_positions)
+            bsz, seq_len = ys_pad.size()
+            for li in range(bsz):
+                # Sample more ground-truth positions the more tokens were wrong.
+                target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
+                if target_num > 0:
+                    # NOTE(review): .cuda() hard-codes GPU training; .to(ys_pad.device)
+                    # would also work on CPU — confirm before changing.
+                    input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num].cuda(), value=0)
+            input_mask = input_mask.eq(1)
+            input_mask = input_mask.masked_fill(~nonpad_positions, False)
+            input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
+
+        # Keep acoustic embeddings where mask is True, ground-truth embeddings elsewhere.
+        sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
+            input_mask_expand_dim, 0)
+        return sematic_embeds * tgt_mask, decoder_out * tgt_mask
+
+    def sampler_with_grad(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, chunk_mask=None):
+        """Same sampling scheme as :meth:`sampler`, but the first decoder pass
+        keeps gradients and its attention loss is returned as ``pre_loss_att``
+        (used when ``use_1st_decoder_loss`` is enabled).
+
+        Returns:
+            (sematic_embeds, decoder_out, pre_loss_att).
+        """
+        # Non-padding mask over target steps: (Batch, Lmax, 1).
+        tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
+        ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
+        if self.share_embedding:
+            # Tied weights: reuse the output projection matrix as the embedding table.
+            ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
+        else:
+            ys_pad_embed = self.decoder.embed(ys_pad_masked)
+        # First pass runs WITH grad so the auxiliary loss can back-propagate.
+        decoder_outs = self.decoder(
+            encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens, chunk_mask
+        )
+        pre_loss_att = self.criterion_att(decoder_outs[0], ys_pad)
+        decoder_out, _ = decoder_outs[0], decoder_outs[1]
+        pred_tokens = decoder_out.argmax(-1)
+        nonpad_positions = ys_pad.ne(self.ignore_id)
+        seq_lens = (nonpad_positions).sum(1)
+        # Per-utterance count of correctly predicted tokens.
+        same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
+        # input_mask == 1 -> keep acoustic embedding; 0 -> substitute ground truth.
+        input_mask = torch.ones_like(nonpad_positions)
+        bsz, seq_len = ys_pad.size()
+        for li in range(bsz):
+            # Sample more ground-truth positions the more tokens were wrong.
+            target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
+            if target_num > 0:
+                # NOTE(review): .cuda() hard-codes GPU training; .to(ys_pad.device)
+                # would also work on CPU — confirm before changing.
+                input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num].cuda(), value=0)
+        input_mask = input_mask.eq(1)
+        input_mask = input_mask.masked_fill(~nonpad_positions, False)
+        input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
+
+        # Keep acoustic embeddings where mask is True, ground-truth embeddings elsewhere.
+        sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
+            input_mask_expand_dim, 0)
+
+        return sematic_embeds * tgt_mask, decoder_out * tgt_mask, pre_loss_att
+
+    def calc_predictor(self, encoder_out, encoder_out_lens):
+        """Run the predictor at inference time (no target labels).
+
+        Side effect: stores the computed SCAMA cross-attention mask on
+        ``self.scama_mask`` for the subsequent :meth:`cal_decoder_with_predictor`
+        call — the two methods must be called on the same utterance.
+
+        Returns:
+            (pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index)
+            as produced by the predictor.
+        """
+        # Non-padding mask over encoder frames: (Batch, 1, T).
+        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+            encoder_out.device)
+        # With chunk-overlap encoding, mask overlapped/shifted regions first
+        # (mirrors the training path in _calc_att_predictor_loss).
+        mask_chunk_predictor = None
+        if self.encoder.overlap_chunk_cls is not None:
+            mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(None,
+                                                                                           device=encoder_out.device,
+                                                                                           batch_size=encoder_out.size(
+                                                                                               0))
+            mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(None, device=encoder_out.device,
+                                                                                   batch_size=encoder_out.size(0))
+            encoder_out = encoder_out * mask_shfit_chunk
+        # No targets at inference: ys_pad and target_label_length are None.
+        pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index = self.predictor(encoder_out,
+                                                                                           None,
+                                                                                           encoder_out_mask,
+                                                                                           ignore_id=self.ignore_id,
+                                                                                           mask_chunk_predictor=mask_chunk_predictor,
+                                                                                           target_label_length=None,
+                                                                                           )
+        # A positive tail threshold extends the alignment by one frame —
+        # presumably to absorb the synthetic tail token; confirm in predictor impl.
+        predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(pre_alphas,
+                                                                                             encoder_out_lens+1 if self.predictor.tail_threshold > 0.0 else encoder_out_lens)
+
+        scama_mask = None
+        if self.encoder.overlap_chunk_cls is not None and self.decoder_attention_chunk_type == 'chunk':
+            # Build the chunk-restricted decoder cross-attention mask from the
+            # predicted frame alignments (same construction as in training).
+            encoder_chunk_size = self.encoder.overlap_chunk_cls.chunk_size_pad_shift_cur
+            attention_chunk_center_bias = 0
+            attention_chunk_size = encoder_chunk_size
+            decoder_att_look_back_factor = self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur
+            mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls.\
+                get_mask_shift_att_chunk_decoder(None,
+                                                 device=encoder_out.device,
+                                                 batch_size=encoder_out.size(0)
+                                                 )
+            scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn(
+                predictor_alignments=predictor_alignments,
+                encoder_sequence_length=encoder_out_lens,
+                chunk_size=1,
+                encoder_chunk_size=encoder_chunk_size,
+                attention_chunk_center_bias=attention_chunk_center_bias,
+                attention_chunk_size=attention_chunk_size,
+                attention_chunk_type=self.decoder_attention_chunk_type,
+                step=None,
+                predictor_mask_chunk_hopping=mask_chunk_predictor,
+                decoder_att_look_back_factor=decoder_att_look_back_factor,
+                mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder,
+                target_length=None,
+                is_training=self.training,
+            )
+        # Stash for cal_decoder_with_predictor (stateful: not thread-safe across
+        # concurrent utterances on one model instance).
+        self.scama_mask = scama_mask
+
+        return pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index
+
     def calc_predictor_chunk(self, encoder_out, cache=None):
         """Run the predictor on one streaming chunk.
 
         Uses ``cache["encoder"]`` as the predictor's carried-over state;
         ``cache`` must therefore not be None here despite the default.
         """
 
         pre_acoustic_embeds, pre_token_length = \
             self.predictor.forward_chunk(encoder_out, cache["encoder"])
         return pre_acoustic_embeds, pre_token_length
 
+    def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
+        """Decode using the predictor's acoustic embeddings as decoder input.
+
+        Relies on ``self.scama_mask`` set by a preceding :meth:`calc_predictor`
+        call on the same utterance.
+
+        Returns:
+            (log-probabilities over the vocabulary, ys_pad_lens unchanged).
+        """
+        decoder_outs = self.decoder(
+            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, self.scama_mask
+        )
+        decoder_out = decoder_outs[0]
+        # Normalize logits to log-probabilities for downstream scoring.
+        decoder_out = torch.log_softmax(decoder_out, dim=-1)
+        return decoder_out, ys_pad_lens
+
     def cal_decoder_with_predictor_chunk(self, encoder_out, sematic_embeds, cache=None):
         decoder_outs = self.decoder.forward_chunk(
             encoder_out, sematic_embeds, cache["decoder"]
@@ -1800,4 +2251,4 @@ class ContextualParaformer(Paraformer):
                     "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
                                                                                   var_dict_tf[name_tf].shape))
 
-        return var_dict_torch_update
+        return var_dict_torch_update

+ 55 - 0
funasr/models/encoder/sanm_encoder.py

@@ -633,6 +633,8 @@ class SANMEncoderChunkOpt(AbsEncoder):
                 self.embed = torch.nn.Linear(input_size, output_size)
         elif input_layer == "pe":
             self.embed = SinusoidalPositionEncoder()
+        elif input_layer == "pe_online":
+            self.embed = StreamSinusoidalPositionEncoder()
         else:
             raise ValueError("unknown input_layer: " + input_layer)
         self.normalize_before = normalize_before
@@ -818,6 +820,59 @@ class SANMEncoderChunkOpt(AbsEncoder):
             return (xs_pad, intermediate_outs), olens, None
         return xs_pad, olens, None
 
+    def _add_overlap_chunk(self, feats: torch.Tensor, cache: dict = {}):
+        """Prepend the cached tail of the previous chunk to ``feats``.
+
+        Also updates ``cache["feats"]`` with the new tail (the last
+        ``chunk_size[0] + chunk_size[2]`` frames) for the next call.
+        An empty ``cache`` means the first chunk: return ``feats`` unchanged.
+
+        NOTE(review): the original annotation said ``np.ndarray`` but torch ops
+        (``torch.cat``, ``.device``) are used — a torch.Tensor is expected.
+        NOTE(review): the mutable default ``cache: dict = {}`` is shared across
+        calls; callers appear to always pass a cache, but confirm before relying
+        on the default.
+        """
+        if len(cache) == 0:
+            return feats
+        cache["feats"] = to_device(cache["feats"], device=feats.device)
+        overlap_feats = torch.cat((cache["feats"], feats), dim=1)
+        # Keep only the overlap window needed by the next chunk.
+        cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :]
+        return overlap_feats
+
+    def forward_chunk(self,
+                      xs_pad: torch.Tensor,
+                      ilens: torch.Tensor,
+                      cache: dict = None,
+                      ctc: CTC = None,
+                      ):
+        """Encode one streaming chunk, stitching overlap context from ``cache``.
+
+        Args:
+            xs_pad: input features for this chunk (Batch, T, D).
+            ilens: input lengths (Batch,).
+            cache: streaming state; must provide "tail_chunk", "feats" and
+                "chunk_size" (see _add_overlap_chunk). Not optional in practice.
+            ctc: CTC module used only for intermediate-CTC conditioning.
+
+        Returns:
+            (xs_pad, ilens, None), or ((xs_pad, intermediate_outs), None, None)
+            when intermediate CTC layers are configured.
+        """
+        # NOTE(review): in-place *= mutates the caller's tensor — confirm callers
+        # do not reuse xs_pad afterwards.
+        xs_pad *= self.output_size() ** 0.5
+        if self.embed is None:
+            xs_pad = xs_pad
+        else:
+            # Streaming positional encoding: the embed layer consumes the cache
+            # to keep absolute positions consistent across chunks.
+            xs_pad = self.embed(xs_pad, cache)
+        if cache["tail_chunk"]:
+            # Final (tail) chunk: re-encode the cached overlap buffer as-is.
+            xs_pad = to_device(cache["feats"], device=xs_pad.device)
+        else:
+            xs_pad = self._add_overlap_chunk(xs_pad, cache)
+        # Masks are None in streaming mode — each chunk is densely valid.
+        encoder_outs = self.encoders0(xs_pad, None, None, None, None)
+        xs_pad, masks = encoder_outs[0], encoder_outs[1]
+        intermediate_outs = []
+        if len(self.interctc_layer_idx) == 0:
+            encoder_outs = self.encoders(xs_pad, None, None, None, None)
+            xs_pad, masks = encoder_outs[0], encoder_outs[1]
+        else:
+            # Run layer by layer so intermediate-CTC taps can be collected.
+            for layer_idx, encoder_layer in enumerate(self.encoders):
+                encoder_outs = encoder_layer(xs_pad, None, None, None, None)
+                xs_pad, masks = encoder_outs[0], encoder_outs[1]
+                if layer_idx + 1 in self.interctc_layer_idx:
+                    encoder_out = xs_pad
+
+                    # intermediate outputs are also normalized
+                    if self.normalize_before:
+                        encoder_out = self.after_norm(encoder_out)
+
+                    intermediate_outs.append((layer_idx + 1, encoder_out))
+
+                    # Self-conditioned CTC: feed the CTC posterior back into the
+                    # encoder stream.
+                    if self.interctc_use_conditioning:
+                        ctc_out = ctc.softmax(encoder_out)
+                        xs_pad = xs_pad + self.conditioning_layer(ctc_out)
+
+        if self.normalize_before:
+            xs_pad = self.after_norm(xs_pad)
+
+        if len(intermediate_outs) > 0:
+            return (xs_pad, intermediate_outs), None, None
+        return xs_pad, ilens, None
+
     def gen_tf2torch_map_dict(self):
         tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
         tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf