瀏覽代碼

Funasr1.0 (#1297)

* fix add_file bug (#1296)

Co-authored-by: shixian.shi <shixian.shi@alibaba-inc.com>

* funasr1.0 uniasr

* funasr1.0 uniasr

---------

Co-authored-by: shixian.shi <shixian.shi@alibaba-inc.com>
zhifu gao 1 年之前
父節點
當前提交
e4035edb46

+ 29 - 0
examples/industrial_data_pretraining/uniasr/demo.py

@@ -0,0 +1,29 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+from funasr import AutoModel
+
+model = AutoModel(model="/Users/zhifu/Downloads/modelscope_models/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-online", model_revision="v2.0.4",
+                  # vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
+                  # vad_model_revision="v2.0.4",
+                  # punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
+                  # punc_model_revision="v2.0.4",
+                  )
+
+res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")
+print(res)
+
+
+''' can not use currently
+from funasr import AutoFrontend
+
+frontend = AutoFrontend(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", model_revision="v2.0.4")
+
+fbanks = frontend(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav", batch_size=2)
+
+for batch_idx, fbank_dict in enumerate(fbanks):
+    res = model.generate(**fbank_dict)
+    print(res)
+'''

+ 11 - 0
examples/industrial_data_pretraining/uniasr/infer.sh

@@ -0,0 +1,11 @@
+
+model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
+model_revision="v2.0.4"
+
+python funasr/bin/inference.py \
++model=${model} \
++model_revision=${model_revision} \
++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav" \
++output_dir="./outputs/debug" \
++device="cpu" \
+

+ 1 - 1
funasr/models/transformer/model.py

@@ -348,7 +348,7 @@ class Transformer(nn.Module):
         scorers["ngram"] = ngram
         
         weights = dict(
-            decoder=1.0 - kwargs.get("decoding_ctc_weight"),
+            decoder=1.0 - kwargs.get("decoding_ctc_weight", 0.0),
             ctc=kwargs.get("decoding_ctc_weight", 0.0),
             lm=kwargs.get("lm_weight", 0.0),
             ngram=kwargs.get("ngram_weight", 0.0),

+ 496 - 0
funasr/models/uniasr/beam_search.py

@@ -0,0 +1,496 @@
+"""Beam search module."""
+
+from itertools import chain
+import logging
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import NamedTuple
+from typing import Tuple
+from typing import Union
+
+import torch
+
+from funasr.metrics.common import end_detect
+from funasr.models.transformer.scorers.scorer_interface import PartialScorerInterface
+from funasr.models.transformer.scorers.scorer_interface import ScorerInterface
+
+
+class Hypothesis(NamedTuple):
+    """Hypothesis data type."""
+
+    yseq: torch.Tensor
+    score: Union[float, torch.Tensor] = 0
+    scores: Dict[str, Union[float, torch.Tensor]] = dict()
+    states: Dict[str, Any] = dict()
+
+    def asdict(self) -> dict:
+        """Convert data to JSON-friendly dict."""
+        return self._replace(
+            yseq=self.yseq.tolist(),
+            score=float(self.score),
+            scores={k: float(v) for k, v in self.scores.items()},
+        )._asdict()
+
+
+
+class BeamSearchScama(torch.nn.Module):
+    """Beam search implementation."""
+
+    def __init__(
+        self,
+        scorers: Dict[str, ScorerInterface],
+        weights: Dict[str, float],
+        beam_size: int,
+        vocab_size: int,
+        sos: int,
+        eos: int,
+        token_list: List[str] = None,
+        pre_beam_ratio: float = 1.5,
+        pre_beam_score_key: str = None,
+    ):
+        """Initialize beam search.
+
+        Args:
+            scorers (dict[str, ScorerInterface]): Dict of decoder modules
+                e.g., Decoder, CTCPrefixScorer, LM
+                The scorer will be ignored if it is `None`
+            weights (dict[str, float]): Dict of weights for each scorers
+                The scorer will be ignored if its weight is 0
+            beam_size (int): The number of hypotheses kept during search
+            vocab_size (int): The number of vocabulary
+            sos (int): Start of sequence id
+            eos (int): End of sequence id
+            token_list (list[str]): List of tokens for debug log
+            pre_beam_score_key (str): key of scores to perform pre-beam search
+            pre_beam_ratio (float): beam size in the pre-beam search
+                will be `int(pre_beam_ratio * beam_size)`
+
+        """
+        super().__init__()
+        # set scorers
+        self.weights = weights
+        self.scorers = dict()
+        self.full_scorers = dict()
+        self.part_scorers = dict()
+        # this module dict is required for recursive cast
+        # `self.to(device, dtype)` in `recog.py`
+        self.nn_dict = torch.nn.ModuleDict()
+        for k, v in scorers.items():
+            w = weights.get(k, 0)
+            if w == 0 or v is None:
+                continue
+            assert isinstance(
+                v, ScorerInterface
+            ), f"{k} ({type(v)}) does not implement ScorerInterface"
+            self.scorers[k] = v
+            if isinstance(v, PartialScorerInterface):
+                self.part_scorers[k] = v
+            else:
+                self.full_scorers[k] = v
+            if isinstance(v, torch.nn.Module):
+                self.nn_dict[k] = v
+
+        # set configurations
+        self.sos = sos
+        self.eos = eos
+        self.token_list = token_list
+        self.pre_beam_size = int(pre_beam_ratio * beam_size)
+        self.beam_size = beam_size
+        self.n_vocab = vocab_size
+        if (
+            pre_beam_score_key is not None
+            and pre_beam_score_key != "full"
+            and pre_beam_score_key not in self.full_scorers
+        ):
+            raise KeyError(f"{pre_beam_score_key} is not found in {self.full_scorers}")
+        self.pre_beam_score_key = pre_beam_score_key
+        self.do_pre_beam = (
+            self.pre_beam_score_key is not None
+            and self.pre_beam_size < self.n_vocab
+            and len(self.part_scorers) > 0
+        )
+
+    def init_hyp(self, x: torch.Tensor) -> List[Hypothesis]:
+        """Get an initial hypothesis data.
+
+        Args:
+            x (torch.Tensor): The encoder output feature
+
+        Returns:
+            Hypothesis: The initial hypothesis.
+
+        """
+        init_states = dict()
+        init_scores = dict()
+        for k, d in self.scorers.items():
+            init_states[k] = d.init_state(x)
+            init_scores[k] = 0.0
+        return [
+            Hypothesis(
+                score=0.0,
+                scores=init_scores,
+                states=init_states,
+                yseq=torch.tensor([self.sos], device=x.device),
+            )
+        ]
+
+    @staticmethod
+    def append_token(xs: torch.Tensor, x: int) -> torch.Tensor:
+        """Append new token to prefix tokens.
+
+        Args:
+            xs (torch.Tensor): The prefix token
+            x (int): The new token to append
+
+        Returns:
+            torch.Tensor: New tensor contains: xs + [x] with xs.dtype and xs.device
+
+        """
+        x = torch.tensor([x], dtype=xs.dtype, device=xs.device)
+        return torch.cat((xs, x))
+
+    def score_full(
+        self, hyp: Hypothesis,
+        x: torch.Tensor,
+        x_mask: torch.Tensor = None,
+        pre_acoustic_embeds: torch.Tensor = None,
+    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
+        """Score new hypothesis by `self.full_scorers`.
+
+        Args:
+            hyp (Hypothesis): Hypothesis with prefix tokens to score
+            x (torch.Tensor): Corresponding input feature
+
+        Returns:
+            Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
+                score dict of `hyp` that has string keys of `self.full_scorers`
+                and tensor score values of shape: `(self.n_vocab,)`,
+                and state dict that has string keys
+                and state values of `self.full_scorers`
+
+        """
+        scores = dict()
+        states = dict()
+        for k, d in self.full_scorers.items():
+            scores[k], states[k] = d.score(hyp.yseq, hyp.states[k], x, x_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds)
+        return scores, states
+
+    def score_partial(
+        self, hyp: Hypothesis, ids: torch.Tensor, x: torch.Tensor
+    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
+        """Score new hypothesis by `self.part_scorers`.
+
+        Args:
+            hyp (Hypothesis): Hypothesis with prefix tokens to score
+            ids (torch.Tensor): 1D tensor of new partial tokens to score
+            x (torch.Tensor): Corresponding input feature
+
+        Returns:
+            Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
+                score dict of `hyp` that has string keys of `self.part_scorers`
+                and tensor score values of shape: `(len(ids),)`,
+                and state dict that has string keys
+                and state values of `self.part_scorers`
+
+        """
+        scores = dict()
+        states = dict()
+        for k, d in self.part_scorers.items():
+            scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k], x)
+        return scores, states
+
+    def beam(
+        self, weighted_scores: torch.Tensor, ids: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Compute topk full token ids and partial token ids.
+
+        Args:
+            weighted_scores (torch.Tensor): The weighted sum scores for each tokens.
+            Its shape is `(self.n_vocab,)`.
+            ids (torch.Tensor): The partial token ids to compute topk
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]:
+                The topk full token ids and partial token ids.
+                Their shapes are `(self.beam_size,)`
+
+        """
+        # no pre beam performed
+        if weighted_scores.size(0) == ids.size(0):
+            top_ids = weighted_scores.topk(self.beam_size)[1]
+            return top_ids, top_ids
+
+        # mask pruned in pre-beam not to select in topk
+        tmp = weighted_scores[ids]
+        weighted_scores[:] = -float("inf")
+        weighted_scores[ids] = tmp
+        top_ids = weighted_scores.topk(self.beam_size)[1]
+        local_ids = weighted_scores[ids].topk(self.beam_size)[1]
+        return top_ids, local_ids
+
+    @staticmethod
+    def merge_scores(
+        prev_scores: Dict[str, float],
+        next_full_scores: Dict[str, torch.Tensor],
+        full_idx: int,
+        next_part_scores: Dict[str, torch.Tensor],
+        part_idx: int,
+    ) -> Dict[str, torch.Tensor]:
+        """Merge scores for new hypothesis.
+
+        Args:
+            prev_scores (Dict[str, float]):
+                The previous hypothesis scores by `self.scorers`
+            next_full_scores (Dict[str, torch.Tensor]): scores by `self.full_scorers`
+            full_idx (int): The next token id for `next_full_scores`
+            next_part_scores (Dict[str, torch.Tensor]):
+                scores of partial tokens by `self.part_scorers`
+            part_idx (int): The new token id for `next_part_scores`
+
+        Returns:
+            Dict[str, torch.Tensor]: The new score dict.
+                Its keys are names of `self.full_scorers` and `self.part_scorers`.
+                Its values are scalar tensors by the scorers.
+
+        """
+        new_scores = dict()
+        for k, v in next_full_scores.items():
+            new_scores[k] = prev_scores[k] + v[full_idx]
+        for k, v in next_part_scores.items():
+            new_scores[k] = prev_scores[k] + v[part_idx]
+        return new_scores
+
+    def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any:
+        """Merge states for new hypothesis.
+
+        Args:
+            states: states of `self.full_scorers`
+            part_states: states of `self.part_scorers`
+            part_idx (int): The new token id for `part_scores`
+
+        Returns:
+            Dict[str, torch.Tensor]: The new score dict.
+                Its keys are names of `self.full_scorers` and `self.part_scorers`.
+                Its values are states of the scorers.
+
+        """
+        new_states = dict()
+        for k, v in states.items():
+            new_states[k] = v
+        for k, d in self.part_scorers.items():
+            new_states[k] = d.select_state(part_states[k], part_idx)
+        return new_states
+
+    def search(
+        self, running_hyps: List[Hypothesis],
+        x: torch.Tensor,
+        x_mask: torch.Tensor = None,
+        pre_acoustic_embeds: torch.Tensor = None,
+    ) -> List[Hypothesis]:
+        """Search new tokens for running hypotheses and encoded speech x.
+
+        Args:
+            running_hyps (List[Hypothesis]): Running hypotheses on beam
+            x (torch.Tensor): Encoded speech feature (T, D)
+
+        Returns:
+            List[Hypotheses]: Best sorted hypotheses
+
+        """
+        best_hyps = []
+        part_ids = torch.arange(self.n_vocab, device=x.device)  # no pre-beam
+        for hyp in running_hyps:
+            # scoring
+            weighted_scores = torch.zeros(self.n_vocab, dtype=x.dtype, device=x.device)
+            scores, states = self.score_full(hyp, x, x_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds)
+            for k in self.full_scorers:
+                weighted_scores += self.weights[k] * scores[k]
+            # partial scoring
+            if self.do_pre_beam:
+                pre_beam_scores = (
+                    weighted_scores
+                    if self.pre_beam_score_key == "full"
+                    else scores[self.pre_beam_score_key]
+                )
+                part_ids = torch.topk(pre_beam_scores, self.pre_beam_size)[1]
+            part_scores, part_states = self.score_partial(hyp, part_ids, x)
+            for k in self.part_scorers:
+                weighted_scores[part_ids] += self.weights[k] * part_scores[k]
+            # add previous hyp score
+            weighted_scores += hyp.score
+
+            # update hyps
+            for j, part_j in zip(*self.beam(weighted_scores, part_ids)):
+                # will be (2 x beam at most)
+                best_hyps.append(
+                    Hypothesis(
+                        score=weighted_scores[j],
+                        yseq=self.append_token(hyp.yseq, j),
+                        scores=self.merge_scores(
+                            hyp.scores, scores, j, part_scores, part_j
+                        ),
+                        states=self.merge_states(states, part_states, part_j),
+                    )
+                )
+
+            # sort and prune 2 x beam -> beam
+            best_hyps = sorted(best_hyps, key=lambda x: x.score, reverse=True)[
+                : min(len(best_hyps), self.beam_size)
+            ]
+        return best_hyps
+
+    def forward(
+        self, x: torch.Tensor,
+        scama_mask: torch.Tensor = None,
+        pre_acoustic_embeds: torch.Tensor = None,
+        maxlenratio: float = 0.0,
+        minlenratio: float = 0.0,
+        maxlen: int = None,
+        minlen: int = 0,
+    ) -> List[Hypothesis]:
+        """Perform beam search.
+
+        Args:
+            x (torch.Tensor): Encoded speech feature (T, D)
+            maxlenratio (float): Input length ratio to obtain max output length.
+                If maxlenratio=0.0 (default), it uses a end-detect function
+                to automatically find maximum hypothesis lengths
+                If maxlenratio<0.0, its absolute value is interpreted
+                as a constant max output length.
+            minlenratio (float): Input length ratio to obtain min output length.
+
+        Returns:
+            list[Hypothesis]: N-best decoding results
+
+        """
+        if maxlen is None:
+            # set length bounds
+            if maxlenratio == 0:
+                maxlen = x.shape[0]
+            elif maxlenratio < 0:
+                maxlen = -1 * int(maxlenratio)
+            else:
+                maxlen = max(1, int(maxlenratio * x.size(0)))
+            minlen = int(minlenratio * x.size(0))
+
+        logging.info("decoder input length: " + str(x.shape[0]))
+        logging.info("max output length: " + str(maxlen))
+        logging.info("min output length: " + str(minlen))
+
+        # main loop of prefix search
+        running_hyps = self.init_hyp(x)
+        ended_hyps = []
+        for i in range(maxlen):
+            logging.debug("position " + str(i))
+            mask_enc = None
+            if scama_mask is not None:
+                token_num_predictor = scama_mask.size(1)
+                token_id_slice = min(i, token_num_predictor-1)
+                mask_enc = scama_mask[:, token_id_slice:token_id_slice+1, :]
+                # if mask_enc.size(1) == 0:
+                #     mask_enc = scama_mask[:, -2:-1, :]
+                #     # mask_enc = torch.zeros_like(mask_enc)
+            pre_acoustic_embeds_cur = None
+            if pre_acoustic_embeds is not None:
+                b, t, d = pre_acoustic_embeds.size()
+                pad = torch.zeros((b, 1, d), dtype=pre_acoustic_embeds.dtype).to(device=pre_acoustic_embeds.device)
+                pre_acoustic_embeds = torch.cat((pre_acoustic_embeds, pad), dim=1)
+                token_id_slice = min(i, t)
+                pre_acoustic_embeds_cur = pre_acoustic_embeds[:, token_id_slice:token_id_slice+1, :]
+
+            best = self.search(running_hyps, x, x_mask=mask_enc, pre_acoustic_embeds=pre_acoustic_embeds_cur)
+            # post process of one iteration
+            running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps)
+            # end detection
+            if maxlenratio == 0.0 and end_detect([h.asdict() for h in ended_hyps], i):
+                logging.info(f"end detected at {i}")
+                break
+            if len(running_hyps) == 0:
+                logging.info("no hypothesis. Finish decoding.")
+                break
+            else:
+                logging.debug(f"remained hypotheses: {len(running_hyps)}")
+
+        nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True)
+        # check the number of hypotheses reaching to eos
+        if len(nbest_hyps) == 0:
+            logging.warning(
+                "there is no N-best results, perform recognition "
+                "again with smaller minlenratio."
+            )
+            return (
+                []
+                if minlenratio < 0.1
+                else self.forward(x, maxlenratio, max(0.0, minlenratio - 0.1))
+            )
+
+        # report the best result
+        for x in nbest_hyps:
+            yseq = "".join([self.token_list[x] for x in x.yseq])
+            logging.debug("nbest: y: {}, yseq: {}, score: {}".format(x.yseq, yseq, x.score))
+        best = nbest_hyps[0]
+        for k, v in best.scores.items():
+            logging.info(
+                f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}"
+            )
+        logging.info(f"total log probability: {best.score:.2f}")
+        logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}")
+        logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
+        if self.token_list is not None:
+            logging.info(
+                "best hypo: "
+                + "".join([self.token_list[x] for x in best.yseq[1:-1]])
+                + "\n"
+            )
+        return nbest_hyps
+
+    def post_process(
+        self,
+        i: int,
+        maxlen: int,
+        maxlenratio: float,
+        running_hyps: List[Hypothesis],
+        ended_hyps: List[Hypothesis],
+    ) -> List[Hypothesis]:
+        """Perform post-processing of beam search iterations.
+
+        Args:
+            i (int): The length of hypothesis tokens.
+            maxlen (int): The maximum length of tokens in beam search.
+            maxlenratio (int): The maximum length ratio in beam search.
+            running_hyps (List[Hypothesis]): The running hypotheses in beam search.
+            ended_hyps (List[Hypothesis]): The ended hypotheses in beam search.
+
+        Returns:
+            List[Hypothesis]: The new running hypotheses.
+
+        """
+        logging.debug(f"the number of running hypotheses: {len(running_hyps)}")
+        if self.token_list is not None:
+            logging.debug(
+                "best hypo: "
+                + "".join([self.token_list[x] for x in running_hyps[0].yseq[1:]])
+            )
+        # add eos in the final loop to avoid that there are no ended hyps
+        if i == maxlen - 1:
+            logging.info("adding <eos> in the last position in the loop")
+            running_hyps = [
+                h._replace(yseq=self.append_token(h.yseq, self.eos))
+                for h in running_hyps
+            ]
+
+        # add ended hypotheses to a final list, and removed them from current hypotheses
+        # (this will be a problem, number of hyps < beam)
+        remained_hyps = []
+        for hyp in running_hyps:
+            if hyp.yseq[-1] == self.eos:
+                # e.g., Word LM needs to add final <eos> score
+                for k, d in chain(self.full_scorers.items(), self.part_scorers.items()):
+                    s = d.final_score(hyp.states[k])
+                    hyp.scores[k] += s
+                    hyp = hyp._replace(score=hyp.score + self.weights[k] * s)
+                ended_hyps.append(hyp)
+            else:
+                remained_hyps.append(hyp)
+        return remained_hyps

+ 276 - 388
funasr/models/uniasr/model.py

@@ -14,14 +14,13 @@ from funasr.models.ctc.ctc import CTC
 from funasr.utils import postprocess_utils
 from funasr.metrics.compute_acc import th_accuracy
 from funasr.utils.datadir_writer import DatadirWriter
-from funasr.models.paraformer.search import Hypothesis
 from funasr.models.paraformer.cif_predictor import mae_loss
 from funasr.train_utils.device_funcs import force_gatherable
 from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
 from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
 from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
 from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
-
+from funasr.models.scama.utils import sequence_mask
 
 @tables.register("model_classes", "UniASR")
 class UniASR(torch.nn.Module):
@@ -31,19 +30,37 @@ class UniASR(torch.nn.Module):
 
     def __init__(
         self,
-        specaug: Optional[str] = None,
-        specaug_conf: Optional[Dict] = None,
+        specaug: str = None,
+        specaug_conf: dict = None,
         normalize: str = None,
-        normalize_conf: Optional[Dict] = None,
+        normalize_conf: dict = None,
         encoder: str = None,
-        encoder_conf: Optional[Dict] = None,
+        encoder_conf: dict = None,
+        encoder2: str = None,
+        encoder2_conf: dict = None,
         decoder: str = None,
-        decoder_conf: Optional[Dict] = None,
-        ctc: str = None,
-        ctc_conf: Optional[Dict] = None,
+        decoder_conf: dict = None,
+        decoder2: str = None,
+        decoder2_conf: dict = None,
         predictor: str = None,
-        predictor_conf: Optional[Dict] = None,
+        predictor_conf: dict = None,
+        predictor_bias: int = 0,
+        predictor_weight: float = 0.0,
+        predictor2: str = None,
+        predictor2_conf: dict = None,
+        predictor2_bias: int = 0,
+        predictor2_weight: float = 0.0,
+        ctc: str = None,
+        ctc_conf: dict = None,
         ctc_weight: float = 0.5,
+        ctc2: str = None,
+        ctc2_conf: dict = None,
+        ctc2_weight: float = 0.5,
+        decoder_attention_chunk_type: str = 'chunk',
+        decoder_attention_chunk_type2: str = 'chunk',
+        stride_conv=None,
+        stride_conv_conf: dict = None,
+        loss_weight_model1: float = 0.5,
         input_size: int = 80,
         vocab_size: int = -1,
         ignore_id: int = -1,
@@ -52,60 +69,72 @@ class UniASR(torch.nn.Module):
         eos: int = 2,
         lsm_weight: float = 0.0,
         length_normalized_loss: bool = False,
-        # report_cer: bool = True,
-        # report_wer: bool = True,
-        # sym_space: str = "<space>",
-        # sym_blank: str = "<blank>",
-        # extract_feats_in_collect_stats: bool = True,
-        # predictor=None,
-        predictor_weight: float = 0.0,
-        predictor_bias: int = 0,
-        sampling_ratio: float = 0.2,
         share_embedding: bool = False,
-        # preencoder: Optional[AbsPreEncoder] = None,
-        # postencoder: Optional[AbsPostEncoder] = None,
-        use_1st_decoder_loss: bool = False,
-        encoder1_encoder2_joint_training: bool = True,
         **kwargs,
         
     ):
-        assert 0.0 <= ctc_weight <= 1.0, ctc_weight
-        assert 0.0 <= interctc_weight < 1.0, interctc_weight
-
         super().__init__()
-        self.blank_id = 0
-        self.sos = 1
-        self.eos = 2
+
+        if specaug is not None:
+            specaug_class = tables.specaug_classes.get(specaug)
+            specaug = specaug_class(**specaug_conf)
+        if normalize is not None:
+            normalize_class = tables.normalize_classes.get(normalize)
+            normalize = normalize_class(**normalize_conf)
+            
+        encoder_class = tables.encoder_classes.get(encoder)
+        encoder = encoder_class(input_size=input_size, **encoder_conf)
+        encoder_output_size = encoder.output_size()
+
+        decoder_class = tables.decoder_classes.get(decoder)
+        decoder = decoder_class(
+            vocab_size=vocab_size,
+            encoder_output_size=encoder_output_size,
+            **decoder_conf,
+        )
+        predictor_class = tables.predictor_classes.get(predictor)
+        predictor = predictor_class(**predictor_conf)
+        
+
+        
+        from funasr.models.transformer.utils.subsampling import Conv1dSubsampling
+        stride_conv = Conv1dSubsampling(**stride_conv_conf, idim=input_size + encoder_output_size,
+                                        odim=input_size + encoder_output_size)
+        stride_conv_output_size = stride_conv.output_size()
+
+        encoder_class = tables.encoder_classes.get(encoder2)
+        encoder2 = encoder_class(input_size=stride_conv_output_size, **encoder2_conf)
+        encoder2_output_size = encoder2.output_size()
+
+        decoder_class = tables.decoder_classes.get(decoder2)
+        decoder2 = decoder_class(
+            vocab_size=vocab_size,
+            encoder_output_size=encoder2_output_size,
+            **decoder2_conf,
+        )
+        predictor_class = tables.predictor_classes.get(predictor2)
+        predictor2 = predictor_class(**predictor2_conf)
+
+
+        
+        self.blank_id = blank_id
+        self.sos = sos
+        self.eos = eos
         self.vocab_size = vocab_size
         self.ignore_id = ignore_id
         self.ctc_weight = ctc_weight
-        self.interctc_weight = interctc_weight
-        self.token_list = token_list.copy()
+        self.ctc2_weight = ctc2_weight
 
-        self.frontend = frontend
         self.specaug = specaug
         self.normalize = normalize
-        self.preencoder = preencoder
-        self.postencoder = postencoder
+        
         self.encoder = encoder
 
-        if not hasattr(self.encoder, "interctc_use_conditioning"):
-            self.encoder.interctc_use_conditioning = False
-        if self.encoder.interctc_use_conditioning:
-            self.encoder.conditioning_layer = torch.nn.Linear(
-                vocab_size, self.encoder.output_size()
-            )
-
         self.error_calculator = None
 
-        # we set self.decoder = None in the CTC mode since
-        # self.decoder parameters were never used and PyTorch complained
-        # and threw an Exception in the multi-GPU experiment.
-        # thanks Jeff Farris for pointing out the issue.
-        if ctc_weight == 1.0:
-            self.decoder = None
-        else:
-            self.decoder = decoder
+        self.decoder = decoder
+        self.ctc = None
+        self.ctc2 = None
 
         self.criterion_att = LabelSmoothingLoss(
             size=vocab_size,
@@ -113,22 +142,13 @@ class UniASR(torch.nn.Module):
             smoothing=lsm_weight,
             normalize_length=length_normalized_loss,
         )
-
-        if report_cer or report_wer:
-            self.error_calculator = ErrorCalculator(
-                token_list, sym_space, sym_blank, report_cer, report_wer
-            )
-
-        if ctc_weight == 0.0:
-            self.ctc = None
-        else:
-            self.ctc = ctc
-
-        self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
+        
         self.predictor = predictor
         self.predictor_weight = predictor_weight
         self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)
-        self.step_cur = 0
+        self.encoder1_encoder2_joint_training = kwargs.get("encoder1_encoder2_joint_training", True)
+        
+
         if self.encoder.overlap_chunk_cls is not None:
             from funasr.models.scama.chunk_utilis import build_scama_mask_for_cross_attention_decoder
             self.build_scama_mask_for_cross_attention_decoder_fn = build_scama_mask_for_cross_attention_decoder
@@ -136,14 +156,10 @@ class UniASR(torch.nn.Module):
 
         self.encoder2 = encoder2
         self.decoder2 = decoder2
-        self.ctc_weight2 = ctc_weight2
-        if ctc_weight2 == 0.0:
-            self.ctc2 = None
-        else:
-            self.ctc2 = ctc2
-        self.interctc_weight2 = interctc_weight2
+        self.ctc2_weight = ctc2_weight
+
         self.predictor2 = predictor2
-        self.predictor_weight2 = predictor_weight2
+        self.predictor2_weight = predictor2_weight
         self.decoder_attention_chunk_type2 = decoder_attention_chunk_type2
         self.stride_conv = stride_conv
         self.loss_weight_model1 = loss_weight_model1
@@ -152,10 +168,10 @@ class UniASR(torch.nn.Module):
             self.build_scama_mask_for_cross_attention_decoder_fn2 = build_scama_mask_for_cross_attention_decoder
             self.decoder_attention_chunk_type2 = decoder_attention_chunk_type2
 
-        self.enable_maas_finetune = enable_maas_finetune
-        self.freeze_encoder2 = freeze_encoder2
-        self.encoder1_encoder2_joint_training = encoder1_encoder2_joint_training
         self.length_normalized_loss = length_normalized_loss
+        self.enable_maas_finetune = kwargs.get("enable_maas_finetune", False)
+        self.freeze_encoder2 = kwargs.get("freeze_encoder2", False)
+        self.beam_search = None
 
     def forward(
         self,
@@ -163,7 +179,7 @@ class UniASR(torch.nn.Module):
         speech_lengths: torch.Tensor,
         text: torch.Tensor,
         text_lengths: torch.Tensor,
-        decoding_ind: int = None,
+        **kwargs,
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
         """Frontend + Encoder + Decoder + Calc loss
         Args:
@@ -172,19 +188,14 @@ class UniASR(torch.nn.Module):
                         text: (Batch, Length)
                         text_lengths: (Batch,)
         """
-        assert text_lengths.dim() == 1, text_lengths.shape
-        # Check that batch_size is unified
-        assert (
-            speech.shape[0]
-            == speech_lengths.shape[0]
-            == text.shape[0]
-            == text_lengths.shape[0]
-        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
+        decoding_ind = kwargs.get("decoding_ind", None)
+        if len(text_lengths.size()) > 1:
+            text_lengths = text_lengths[:, 0]
+        if len(speech_lengths.size()) > 1:
+            speech_lengths = speech_lengths[:, 0]
+
         batch_size = speech.shape[0]
 
-        # for data-parallel
-        text = text[:, : text_lengths.max()]
-        speech = speech[:, :speech_lengths.max()]
 
         ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
         # 1. Encoder
@@ -194,10 +205,6 @@ class UniASR(torch.nn.Module):
         else:
             speech_raw, encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
 
-        intermediate_outs = None
-        if isinstance(encoder_out, tuple):
-            intermediate_outs = encoder_out[1]
-            encoder_out = encoder_out[0]
 
         loss_att, acc_att, cer_att, wer_att = None, None, None, None
         loss_ctc, cer_ctc = None, None
@@ -210,62 +217,12 @@ class UniASR(torch.nn.Module):
             # 1. CTC branch
             if self.enable_maas_finetune:
                 with torch.no_grad():
-                    if self.ctc_weight != 0.0:
-                        if self.encoder.overlap_chunk_cls is not None:
-                            encoder_out_ctc, encoder_out_lens_ctc = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out,
-                                                                                                                encoder_out_lens,
-                                                                                                                chunk_outs=None)
-                        loss_ctc, cer_ctc = self._calc_ctc_loss(
-                            encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
-                        )
-
-                        # Collect CTC branch stats
-                        stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
-                        stats["cer_ctc"] = cer_ctc
-
-                    # Intermediate CTC (optional)
-                    loss_interctc = 0.0
-                    if self.interctc_weight != 0.0 and intermediate_outs is not None:
-                        for layer_idx, intermediate_out in intermediate_outs:
-                            # we assume intermediate_out has the same length & padding
-                            # as those of encoder_out
-                            if self.encoder.overlap_chunk_cls is not None:
-                                encoder_out_ctc, encoder_out_lens_ctc = \
-                                    self.encoder.overlap_chunk_cls.remove_chunk(
-                                        intermediate_out,
-                                        encoder_out_lens,
-                                        chunk_outs=None)
-                            loss_ic, cer_ic = self._calc_ctc_loss(
-                                encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
-                            )
-                            loss_interctc = loss_interctc + loss_ic
-
-                            # Collect Intermedaite CTC stats
-                            stats["loss_interctc_layer{}".format(layer_idx)] = (
-                                loss_ic.detach() if loss_ic is not None else None
-                            )
-                            stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
-
-                        loss_interctc = loss_interctc / len(intermediate_outs)
-
-                        # calculate whole encoder loss
-                        loss_ctc = (
-                                    1 - self.interctc_weight
-                                ) * loss_ctc + self.interctc_weight * loss_interctc
-
-                    # 2b. Attention decoder branch
-                    if self.ctc_weight != 1.0:
-                        loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_predictor_loss(
-                            encoder_out, encoder_out_lens, text, text_lengths
-                        )
-
-                    # 3. CTC-Att loss definition
-                    if self.ctc_weight == 0.0:
-                        loss = loss_att + loss_pre * self.predictor_weight
-                    elif self.ctc_weight == 1.0:
-                        loss = loss_ctc
-                    else:
-                        loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
+
+                    loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_predictor_loss(
+                        encoder_out, encoder_out_lens, text, text_lengths
+                    )
+
+                    loss = loss_att + loss_pre * self.predictor_weight
 
                     # Collect Attn branch stats
                     stats["loss_att"] = loss_att.detach() if loss_att is not None else None
@@ -274,62 +231,13 @@ class UniASR(torch.nn.Module):
                     stats["wer"] = wer_att
                     stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
             else:
-                if self.ctc_weight != 0.0:
-                    if self.encoder.overlap_chunk_cls is not None:
-                        encoder_out_ctc, encoder_out_lens_ctc = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out,
-                                                                                                            encoder_out_lens,
-                                                                                                            chunk_outs=None)
-                    loss_ctc, cer_ctc = self._calc_ctc_loss(
-                        encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
-                    )
+                
+                loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_predictor_loss(
+                    encoder_out, encoder_out_lens, text, text_lengths
+                )
 
-                    # Collect CTC branch stats
-                    stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
-                    stats["cer_ctc"] = cer_ctc
-
-                    # Intermediate CTC (optional)
-                loss_interctc = 0.0
-                if self.interctc_weight != 0.0 and intermediate_outs is not None:
-                    for layer_idx, intermediate_out in intermediate_outs:
-                        # we assume intermediate_out has the same length & padding
-                        # as those of encoder_out
-                        if self.encoder.overlap_chunk_cls is not None:
-                            encoder_out_ctc, encoder_out_lens_ctc = \
-                                self.encoder.overlap_chunk_cls.remove_chunk(
-                                    intermediate_out,
-                                    encoder_out_lens,
-                                    chunk_outs=None)
-                        loss_ic, cer_ic = self._calc_ctc_loss(
-                            encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
-                        )
-                        loss_interctc = loss_interctc + loss_ic
-
-                        # Collect Intermedaite CTC stats
-                        stats["loss_interctc_layer{}".format(layer_idx)] = (
-                            loss_ic.detach() if loss_ic is not None else None
-                        )
-                        stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
-
-                    loss_interctc = loss_interctc / len(intermediate_outs)
-
-                    # calculate whole encoder loss
-                    loss_ctc = (
-                                1 - self.interctc_weight
-                            ) * loss_ctc + self.interctc_weight * loss_interctc
-
-                # 2b. Attention decoder branch
-                if self.ctc_weight != 1.0:
-                    loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_predictor_loss(
-                        encoder_out, encoder_out_lens, text, text_lengths
-                    )
 
-                # 3. CTC-Att loss definition
-                if self.ctc_weight == 0.0:
-                    loss = loss_att + loss_pre * self.predictor_weight
-                elif self.ctc_weight == 1.0:
-                    loss = loss_ctc
-                else:
-                    loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
+                loss = loss_att + loss_pre * self.predictor_weight
 
                 # Collect Attn branch stats
                 stats["loss_att"] = loss_att.detach() if loss_att is not None else None
@@ -354,67 +262,14 @@ class UniASR(torch.nn.Module):
             if isinstance(encoder_out, tuple):
                 intermediate_outs = encoder_out[1]
                 encoder_out = encoder_out[0]
-            # CTC2
-            if self.ctc_weight2 != 0.0:
-                if self.encoder2.overlap_chunk_cls is not None:
-                    encoder_out_ctc, encoder_out_lens_ctc = \
-                        self.encoder2.overlap_chunk_cls.remove_chunk(
-                            encoder_out,
-                            encoder_out_lens,
-                            chunk_outs=None,
-                        )
-                loss_ctc, cer_ctc = self._calc_ctc_loss2(
-                    encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
-                )
-
-                # Collect CTC branch stats
-                stats["loss_ctc2"] = loss_ctc.detach() if loss_ctc is not None else None
-                stats["cer_ctc2"] = cer_ctc
-
-            # Intermediate CTC (optional)
-            loss_interctc = 0.0
-            if self.interctc_weight2 != 0.0 and intermediate_outs is not None:
-                for layer_idx, intermediate_out in intermediate_outs:
-                    # we assume intermediate_out has the same length & padding
-                    # as those of encoder_out
-                    if self.encoder2.overlap_chunk_cls is not None:
-                        encoder_out_ctc, encoder_out_lens_ctc = \
-                            self.encoder2.overlap_chunk_cls.remove_chunk(
-                                intermediate_out,
-                                encoder_out_lens,
-                                chunk_outs=None)
-                    loss_ic, cer_ic = self._calc_ctc_loss2(
-                        encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
-                    )
-                    loss_interctc = loss_interctc + loss_ic
-
-                    # Collect Intermedaite CTC stats
-                    stats["loss_interctc_layer{}2".format(layer_idx)] = (
-                        loss_ic.detach() if loss_ic is not None else None
-                    )
-                    stats["cer_interctc_layer{}2".format(layer_idx)] = cer_ic
 
-                loss_interctc = loss_interctc / len(intermediate_outs)
 
-                # calculate whole encoder loss
-                loss_ctc = (
-                               1 - self.interctc_weight2
-                           ) * loss_ctc + self.interctc_weight2 * loss_interctc
+            loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_predictor_loss2(
+                encoder_out, encoder_out_lens, text, text_lengths
+            )
 
-            # 2b. Attention decoder branch
-            if self.ctc_weight2 != 1.0:
-                loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_predictor_loss2(
-                    encoder_out, encoder_out_lens, text, text_lengths
-                )
 
-            # 3. CTC-Att loss definition
-            if self.ctc_weight2 == 0.0:
-                loss = loss_att + loss_pre * self.predictor_weight2
-            elif self.ctc_weight2 == 1.0:
-                loss = loss_ctc
-            else:
-                loss = self.ctc_weight2 * loss_ctc + (
-                    1 - self.ctc_weight2) * loss_att + loss_pre * self.predictor_weight2
+            loss = loss_att + loss_pre * self.predictor2_weight
 
             # Collect Attn branch stats
             stats["loss_att2"] = loss_att.detach() if loss_att is not None else None
@@ -422,6 +277,7 @@ class UniASR(torch.nn.Module):
             stats["cer2"] = cer_att
             stats["wer2"] = wer_att
             stats["loss_pre2"] = loss_pre.detach().cpu() if loss_pre is not None else None
+        
         loss2 = loss
 
         loss = loss1 * self.loss_weight_model1 + loss2 * (1 - self.loss_weight_model1)
@@ -456,61 +312,31 @@ class UniASR(torch.nn.Module):
         return {"feats": feats, "feats_lengths": feats_lengths}
 
     def encode(
-        self, speech: torch.Tensor, speech_lengths: torch.Tensor, ind: int = 0,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
+    ):
         """Frontend + Encoder. Note that this method is used by asr_inference.py
         Args:
                         speech: (Batch, Length, ...)
                         speech_lengths: (Batch, )
         """
+        ind = kwargs.get("ind", 0)
         with autocast(False):
-            # 1. Extract feats
-            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
-
-            # 2. Data augmentation
+            # Data augmentation
             if self.specaug is not None and self.training:
-                feats, feats_lengths = self.specaug(feats, feats_lengths)
-
-            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+                speech, speech_lengths = self.specaug(speech, speech_lengths)
+    
+            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
             if self.normalize is not None:
-                feats, feats_lengths = self.normalize(feats, feats_lengths)
-        speech_raw = feats.clone().to(feats.device)
-        # Pre-encoder, e.g. used for raw input data
-        if self.preencoder is not None:
-            feats, feats_lengths = self.preencoder(feats, feats_lengths)
+                speech, speech_lengths = self.normalize(speech, speech_lengths)
+                
+        speech_raw = speech.clone().to(speech.device)
+
 
         # 4. Forward encoder
-        # feats: (Batch, Length, Dim)
-        # -> encoder_out: (Batch, Length2, Dim2)
-        if self.encoder.interctc_use_conditioning:
-            encoder_out, encoder_out_lens, _ = self.encoder(
-                feats, feats_lengths, ctc=self.ctc, ind=ind
-            )
-        else:
-            encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths, ind=ind)
-        intermediate_outs = None
+        encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths, ind=ind)
         if isinstance(encoder_out, tuple):
-            intermediate_outs = encoder_out[1]
             encoder_out = encoder_out[0]
 
-        # Post-encoder, e.g. NLU
-        if self.postencoder is not None:
-            encoder_out, encoder_out_lens = self.postencoder(
-                encoder_out, encoder_out_lens
-            )
-
-        assert encoder_out.size(0) == speech.size(0), (
-            encoder_out.size(),
-            speech.size(0),
-        )
-        assert encoder_out.size(1) <= encoder_out_lens.max(), (
-            encoder_out.size(),
-            encoder_out_lens.max(),
-        )
-
-        if intermediate_outs is not None:
-            return (encoder_out, intermediate_outs), encoder_out_lens
-
         return speech_raw, encoder_out, encoder_out_lens
 
     def encode2(
@@ -519,28 +345,15 @@ class UniASR(torch.nn.Module):
         encoder_out_lens: torch.Tensor,
         speech: torch.Tensor,
         speech_lengths: torch.Tensor,
-        ind: int = 0,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        **kwargs,
+    ):
         """Frontend + Encoder. Note that this method is used by asr_inference.py
         Args:
                         speech: (Batch, Length, ...)
                         speech_lengths: (Batch, )
         """
-        # with autocast(False):
-        #     # 1. Extract feats
-        #     feats, feats_lengths = self._extract_feats(speech, speech_lengths)
-        #
-        #     # 2. Data augmentation
-        #     if self.specaug is not None and self.training:
-        #         feats, feats_lengths = self.specaug(feats, feats_lengths)
-        #
-        #     # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
-        #     if self.normalize is not None:
-        #         feats, feats_lengths = self.normalize(feats, feats_lengths)
-
-        # Pre-encoder, e.g. used for raw input data
-        # if self.preencoder is not None:
-        #     feats, feats_lengths = self.preencoder(feats, feats_lengths)
+
+        ind = kwargs.get("ind", 0)
         encoder_out_rm, encoder_out_lens_rm = self.encoder.overlap_chunk_cls.remove_chunk(
             encoder_out,
             encoder_out_lens,
@@ -557,55 +370,14 @@ class UniASR(torch.nn.Module):
         # 4. Forward encoder
         # feats: (Batch, Length, Dim)
         # -> encoder_out: (Batch, Length2, Dim2)
-        if self.encoder2.interctc_use_conditioning:
-            encoder_out, encoder_out_lens, _ = self.encoder2(
-                speech, speech_lengths, ctc=self.ctc2, ind=ind
-            )
-        else:
-            encoder_out, encoder_out_lens, _ = self.encoder2(speech, speech_lengths, ind=ind)
-        intermediate_outs = None
+
+        encoder_out, encoder_out_lens, _ = self.encoder2(speech, speech_lengths, ind=ind)
         if isinstance(encoder_out, tuple):
-            intermediate_outs = encoder_out[1]
             encoder_out = encoder_out[0]
 
-        # # Post-encoder, e.g. NLU
-        # if self.postencoder is not None:
-        #     encoder_out, encoder_out_lens = self.postencoder(
-        #         encoder_out, encoder_out_lens
-        #     )
-
-        assert encoder_out.size(0) == speech.size(0), (
-            encoder_out.size(),
-            speech.size(0),
-        )
-        assert encoder_out.size(1) <= encoder_out_lens.max(), (
-            encoder_out.size(),
-            encoder_out_lens.max(),
-        )
-
-        if intermediate_outs is not None:
-            return (encoder_out, intermediate_outs), encoder_out_lens
 
         return encoder_out, encoder_out_lens
 
-    def _extract_feats(
-        self, speech: torch.Tensor, speech_lengths: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        assert speech_lengths.dim() == 1, speech_lengths.shape
-
-        # for data-parallel
-        speech = speech[:, : speech_lengths.max()]
-
-        if self.frontend is not None:
-            # Frontend
-            #  e.g. STFT and Feature extract
-            #       data_loader may send time-domain signal in this case
-            # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
-            feats, feats_lengths = self.frontend(speech, speech_lengths)
-        else:
-            # No frontend and no feature extract
-            feats, feats_lengths = speech, speech_lengths
-        return feats, feats_lengths
 
     def nll(
         self,
@@ -1024,36 +796,152 @@ class UniASR(torch.nn.Module):
 
         return pre_acoustic_embeds, pre_token_length, predictor_alignments, predictor_alignments_len, scama_mask
 
-    def _calc_ctc_loss(
-        self,
-        encoder_out: torch.Tensor,
-        encoder_out_lens: torch.Tensor,
-        ys_pad: torch.Tensor,
-        ys_pad_lens: torch.Tensor,
-    ):
-        # Calc CTC loss
-        loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
+    def init_beam_search(self,
+                         **kwargs,
+                         ):
+        from funasr.models.uniasr.beam_search import BeamSearchScama
+        from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
+        from funasr.models.transformer.scorers.length_bonus import LengthBonus
 
-        # Calc CER using CTC
-        cer_ctc = None
-        if not self.training and self.error_calculator is not None:
-            ys_hat = self.ctc.argmax(encoder_out).data
-            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
-        return loss_ctc, cer_ctc
+        decoding_mode = kwargs.get("decoding_mode", "model1")
+        if decoding_mode == "model1":
+            decoder = self.decoder
+        else:
+            decoder = self.decoder2
+        # 1. Build ASR model
+        scorers = {}
+    
+        if self.ctc != None:
+            ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
+            scorers.update(
+                ctc=ctc
+            )
+        token_list = kwargs.get("token_list")
+        scorers.update(
+            decoder=decoder,
+            length_bonus=LengthBonus(len(token_list)),
+        )
+    
+        # 3. Build ngram model
+        # ngram is not supported now
+        ngram = None
+        scorers["ngram"] = ngram
+    
+        weights = dict(
+            decoder=1.0 - kwargs.get("decoding_ctc_weight", 0.0),
+            ctc=kwargs.get("decoding_ctc_weight", 0.0),
+            lm=kwargs.get("lm_weight", 0.0),
+            ngram=kwargs.get("ngram_weight", 0.0),
+            length_bonus=kwargs.get("penalty", 0.0),
+        )
+        beam_search = BeamSearchScama(
+            beam_size=kwargs.get("beam_size", 5),
+            weights=weights,
+            scorers=scorers,
+            sos=self.sos,
+            eos=self.eos,
+            vocab_size=len(token_list),
+            token_list=token_list,
+            pre_beam_score_key=None if self.ctc_weight == 1.0 else "full",
+        )
+        
+        self.beam_search = beam_search
+
+    def inference(self,
+                  data_in,
+                  data_lengths=None,
+                  key: list = None,
+                  tokenizer=None,
+                  frontend=None,
+                  **kwargs,
+                  ):
+
+        decoding_model = kwargs.get("decoding_model", "normal")
+        token_num_relax = kwargs.get("token_num_relax", 5)
+        if decoding_model == "fast":
+            decoding_ind = 0
+            decoding_mode = "model1"
+        elif decoding_model == "offline":
+            decoding_ind = 1
+            decoding_mode = "model2"
+        else:
+            decoding_ind = 0
+            decoding_mode = "model2"
+        # init beamsearch
+        
+        if self.beam_search is None:
+            logging.info("enable beam_search")
+            self.init_beam_search(decoding_mode=decoding_mode, **kwargs)
+            self.nbest = kwargs.get("nbest", 1)
+    
+        meta_data = {}
+        if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank":  # fbank
+            speech, speech_lengths = data_in, data_lengths
+            if len(speech.shape) < 3:
+                speech = speech[None, :, :]
+            if speech_lengths is None:
+                speech_lengths = speech.shape[1]
+        else:
+            # extract fbank feats
+            time1 = time.perf_counter()
+            audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
+                                                            data_type=kwargs.get("data_type", "sound"),
+                                                            tokenizer=tokenizer)
+            time2 = time.perf_counter()
+            meta_data["load_data"] = f"{time2 - time1:0.3f}"
+            speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
+                                                   frontend=frontend)
+            time3 = time.perf_counter()
+            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+            meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+    
+        speech = speech.to(device=kwargs["device"])
+        speech_lengths = speech_lengths.to(device=kwargs["device"])
+        speech_raw = speech.clone().to(device=kwargs["device"])
+        # Encoder
+        _, encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=decoding_ind)
+        if decoding_mode == "model1":
+            predictor_outs = self.calc_predictor_mask(encoder_out, encoder_out_lens)
+        else:
+            encoder_out, encoder_out_lens = self.encode2(encoder_out, encoder_out_lens, speech_raw, speech_lengths, ind=decoding_ind)
+            predictor_outs = self.calc_predictor_mask2(encoder_out, encoder_out_lens)
+
+
+        scama_mask = predictor_outs[4]
+        pre_token_length = predictor_outs[1]
+        pre_acoustic_embeds = predictor_outs[0]
+        maxlen = pre_token_length.sum().item() + token_num_relax
+        minlen = max(0, pre_token_length.sum().item() - token_num_relax)
+        # c. Passed the encoder result and the beam search
+        nbest_hyps = self.beam_search(
+            x=encoder_out[0], scama_mask=scama_mask, pre_acoustic_embeds=pre_acoustic_embeds, maxlenratio=0.0,
+            minlenratio=0.0, maxlen=int(maxlen), minlen=int(minlen),
+        )
 
-    def _calc_ctc_loss2(
-        self,
-        encoder_out: torch.Tensor,
-        encoder_out_lens: torch.Tensor,
-        ys_pad: torch.Tensor,
-        ys_pad_lens: torch.Tensor,
-    ):
-        # Calc CTC loss
-        loss_ctc = self.ctc2(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
-
-        # Calc CER using CTC
-        cer_ctc = None
-        if not self.training and self.error_calculator is not None:
-            ys_hat = self.ctc2.argmax(encoder_out).data
-            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
-        return loss_ctc, cer_ctc
+        nbest_hyps = nbest_hyps[: self.nbest]
+
+        results = []
+        for hyp in nbest_hyps:
+
+            # remove sos/eos and get results
+            last_pos = -1
+            if isinstance(hyp.yseq, list):
+                token_int = hyp.yseq[1:last_pos]
+            else:
+                token_int = hyp.yseq[1:last_pos].tolist()
+
+            # remove blank symbol id, which is assumed to be 0
+            token_int = list(filter(lambda x: x != 0, token_int))
+
+
+            # Change integer-ids to tokens
+            token = tokenizer.ids2tokens(token_int)
+            text_postprocessed = tokenizer.tokens2text(token)
+            if not hasattr(tokenizer, "bpemodel"):
+                text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
+    
+
+            result_i = {"key": key[0], "text": text_postprocessed}
+            results.append(result_i)
+
+        return results, meta_data