فهرست منبع

update funasr 0.1.3

lzr265946 3 سال پیش
والد
کامیت
a9e857e452

+ 1 - 0
egs_modelscope/common/modelscope_utils/modelscope_infer.sh

@@ -65,6 +65,7 @@ for dset in ${test_sets}; do
     ${decode_cmd} --max-jobs-run "${inference_nj}" JOB=1:"${inference_nj}" "${_logdir}"/asr_inference.JOB.log \
         python -m funasr.bin.modelscope_infer \
               --model_name ${model_name} \
+              --model_revision ${model_revision} \
               --wav_list ${_logdir}/keys.JOB.scp \
               --output_file ${_logdir}/text.JOB \
               --gpuid_list ${gpuid_list} \

+ 687 - 0
funasr/bin/asr_inference_modelscope.py

@@ -0,0 +1,687 @@
+#!/usr/bin/env python3
+# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
+#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
+
+import argparse
+import logging
+import sys
+from pathlib import Path
+from typing import Any
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Union
+from typing import Dict
+
+import numpy as np
+import torch
+from typeguard import check_argument_types
+from typeguard import check_return_type
+
+from funasr.fileio.datadir_writer import DatadirWriter
+from funasr.modules.beam_search.batch_beam_search import BatchBeamSearch
+from funasr.modules.beam_search.batch_beam_search_online_sim import BatchBeamSearchOnlineSim
+from funasr.modules.beam_search.beam_search import BeamSearch
+from funasr.modules.beam_search.beam_search import Hypothesis
+from funasr.modules.scorers.ctc import CTCPrefixScorer
+from funasr.modules.scorers.length_bonus import LengthBonus
+from funasr.modules.scorers.scorer_interface import BatchScorerInterface
+from funasr.modules.subsampling import TooShortUttError
+from funasr.tasks.asr import ASRTask
+from funasr.tasks.lm import LMTask
+from funasr.text.build_tokenizer import build_tokenizer
+from funasr.text.token_id_converter import TokenIDConverter
+from funasr.torch_utils.device_funcs import to_device
+from funasr.torch_utils.set_all_random_seed import set_all_random_seed
+from funasr.utils import config_argparse
+from funasr.utils.cli_utils import get_commandline_args
+from funasr.utils.types import str2bool
+from funasr.utils.types import str2triple_str
+from funasr.utils.types import str_or_none
+from funasr.utils import asr_utils, wav_utils, postprocess_utils
+from funasr.models.frontend.wav_frontend import WavFrontend
+
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+header_colors = '\033[95m'
+end_colors = '\033[0m'
+
+global_asr_language: str = 'zh-cn'
+global_sample_rate: Union[int, Dict[Any, int]] = {
+    'audio_fs': 16000,
+    'model_fs': 16000
+}
+
+class Speech2Text:
+    """Speech2Text class
+
+    Examples:
+        >>> import soundfile
+        >>> speech2text = Speech2Text("asr_config.yml", "asr.pth")
+        >>> audio, rate = soundfile.read("speech.wav")
+        >>> speech2text(audio)
+        [(text, token, token_int, hypothesis object), ...]
+
+    """
+
+    def __init__(
+            self,
+            asr_train_config: Union[Path, str] = None,
+            asr_model_file: Union[Path, str] = None,
+            lm_train_config: Union[Path, str] = None,
+            lm_file: Union[Path, str] = None,
+            token_type: str = None,
+            bpemodel: str = None,
+            device: str = "cpu",
+            maxlenratio: float = 0.0,
+            minlenratio: float = 0.0,
+            batch_size: int = 1,
+            dtype: str = "float32",
+            beam_size: int = 20,
+            ctc_weight: float = 0.5,
+            lm_weight: float = 1.0,
+            ngram_weight: float = 0.9,
+            penalty: float = 0.0,
+            nbest: int = 1,
+            streaming: bool = False,
+            frontend_conf: dict = None,
+            **kwargs,
+    ):
+        assert check_argument_types()
+
+        # 1. Build ASR model
+        scorers = {}
+        asr_model, asr_train_args = ASRTask.build_model_from_file(
+            asr_train_config, asr_model_file, device
+        )
+        if asr_model.frontend is None and frontend_conf is not None:
+            frontend = WavFrontend(**frontend_conf)
+            asr_model.frontend = frontend
+        asr_model.to(dtype=getattr(torch, dtype)).eval()
+
+        decoder = asr_model.decoder
+
+        ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
+        token_list = asr_model.token_list
+        scorers.update(
+            decoder=decoder,
+            ctc=ctc,
+            length_bonus=LengthBonus(len(token_list)),
+        )
+
+        # 2. Build Language model
+        if lm_train_config is not None:
+            lm, lm_train_args = LMTask.build_model_from_file(
+                lm_train_config, lm_file, device
+            )
+            scorers["lm"] = lm.lm
+
+        # 3. Build ngram model
+        # ngram is not supported now
+        ngram = None
+        scorers["ngram"] = ngram
+
+        # 4. Build BeamSearch object
+        # transducer is not supported now
+        beam_search_transducer = None
+
+        weights = dict(
+            decoder=1.0 - ctc_weight,
+            ctc=ctc_weight,
+            lm=lm_weight,
+            ngram=ngram_weight,
+            length_bonus=penalty,
+        )
+        beam_search = BeamSearch(
+            beam_size=beam_size,
+            weights=weights,
+            scorers=scorers,
+            sos=asr_model.sos,
+            eos=asr_model.eos,
+            vocab_size=len(token_list),
+            token_list=token_list,
+            pre_beam_score_key=None if ctc_weight == 1.0 else "full",
+        )
+
+        # TODO(karita): make all scorers batchfied
+        if batch_size == 1:
+            non_batch = [
+                k
+                for k, v in beam_search.full_scorers.items()
+                if not isinstance(v, BatchScorerInterface)
+            ]
+            if len(non_batch) == 0:
+                if streaming:
+                    beam_search.__class__ = BatchBeamSearchOnlineSim
+                    beam_search.set_streaming_config(asr_train_config)
+                    logging.info(
+                        "BatchBeamSearchOnlineSim implementation is selected."
+                    )
+                else:
+                    beam_search.__class__ = BatchBeamSearch
+                    logging.info("BatchBeamSearch implementation is selected.")
+            else:
+                logging.warning(
+                    f"As non-batch scorers {non_batch} are found, "
+                    f"fall back to non-batch implementation."
+                )
+
+            beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
+            for scorer in scorers.values():
+                if isinstance(scorer, torch.nn.Module):
+                    scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
+            logging.info(f"Beam_search: {beam_search}")
+            logging.info(f"Decoding device={device}, dtype={dtype}")
+
+        # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
+        if token_type is None:
+            token_type = asr_train_args.token_type
+        if bpemodel is None:
+            bpemodel = asr_train_args.bpemodel
+
+        if token_type is None:
+            tokenizer = None
+        elif token_type == "bpe":
+            if bpemodel is not None:
+                tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
+            else:
+                tokenizer = None
+        else:
+            tokenizer = build_tokenizer(token_type=token_type)
+        converter = TokenIDConverter(token_list=token_list)
+        logging.info(f"Text tokenizer: {tokenizer}")
+
+        self.asr_model = asr_model
+        self.asr_train_args = asr_train_args
+        self.converter = converter
+        self.tokenizer = tokenizer
+        self.beam_search = beam_search
+        self.beam_search_transducer = beam_search_transducer
+        self.maxlenratio = maxlenratio
+        self.minlenratio = minlenratio
+        self.device = device
+        self.dtype = dtype
+        self.nbest = nbest
+
+    @torch.no_grad()
+    def __call__(
+            self, speech: Union[torch.Tensor, np.ndarray]
+    ) -> List[
+        Tuple[
+            Optional[str],
+            List[str],
+            List[int],
+            Union[Hypothesis],
+        ]
+    ]:
+        """Inference
+
+        Args:
+            speech: Input speech data
+        Returns:
+            text, token, token_int, hyp
+
+        """
+        assert check_argument_types()
+
+        # Input as audio signal
+        if isinstance(speech, np.ndarray):
+            speech = torch.tensor(speech)
+
+        # data: (Nsamples,) -> (1, Nsamples)
+        speech = speech.unsqueeze(0).to(getattr(torch, self.dtype))
+        lfr_factor = max(1, (speech.size()[-1] // 80) - 1)
+        # lengths: (1,)
+        lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
+        batch = {"speech": speech, "speech_lengths": lengths}
+
+        # a. To device
+        batch = to_device(batch, device=self.device)
+
+        # b. Forward Encoder
+        enc, _ = self.asr_model.encode(**batch)
+        if isinstance(enc, tuple):
+            enc = enc[0]
+        assert len(enc) == 1, len(enc)
+
+        # c. Passed the encoder result and the beam search
+        nbest_hyps = self.beam_search(
+            x=enc[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
+        )
+
+        nbest_hyps = nbest_hyps[: self.nbest]
+
+        results = []
+        for hyp in nbest_hyps:
+            assert isinstance(hyp, (Hypothesis)), type(hyp)
+
+            # remove sos/eos and get results
+            last_pos = -1
+            if isinstance(hyp.yseq, list):
+                token_int = hyp.yseq[1:last_pos]
+            else:
+                token_int = hyp.yseq[1:last_pos].tolist()
+
+            # remove blank symbol id, which is assumed to be 0
+            token_int = list(filter(lambda x: x != 0, token_int))
+
+            # Change integer-ids to tokens
+            token = self.converter.ids2tokens(token_int)
+
+            if self.tokenizer is not None:
+                text = self.tokenizer.tokens2text(token)
+            else:
+                text = None
+            results.append((text, token, token_int, hyp))
+
+        assert check_return_type(results)
+        return results
+
+
+def inference(
+        maxlenratio: float,
+        minlenratio: float,
+        batch_size: int,
+        dtype: str,
+        beam_size: int,
+        ngpu: int,
+        seed: int,
+        ctc_weight: float,
+        lm_weight: float,
+        ngram_weight: float,
+        penalty: float,
+        nbest: int,
+        num_workers: int,
+        log_level: Union[int, str],
+        data_path_and_name_and_type: list,
+        audio_lists: Union[List[Any], bytes],
+        key_file: Optional[str],
+        asr_train_config: Optional[str],
+        asr_model_file: Optional[str],
+        lm_train_config: Optional[str],
+        lm_file: Optional[str],
+        word_lm_train_config: Optional[str],
+        token_type: Optional[str],
+        bpemodel: Optional[str],
+        output_dir: Optional[str],
+        allow_variable_data_keys: bool,
+        streaming: bool,
+        frontend_conf: dict = None,
+        fs: Union[dict, int] = 16000,
+        **kwargs,
+) -> List[Any]:
+    assert check_argument_types()
+    if batch_size > 1:
+        raise NotImplementedError("batch decoding is not implemented")
+    if word_lm_train_config is not None:
+        raise NotImplementedError("Word LM is not implemented")
+    if ngpu > 1:
+        raise NotImplementedError("only single GPU decoding is supported")
+
+    logging.basicConfig(
+        level=log_level,
+        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+    )
+
+    if ngpu >= 1:
+        device = "cuda"
+    else:
+        device = "cpu"
+    features_type: str = data_path_and_name_and_type[1]
+    hop_length: int = 160
+    sr: int = 16000
+    if isinstance(fs, int):
+        sr = fs
+    else:
+        if 'model_fs' in fs and fs['model_fs'] is not None:
+            sr = fs['model_fs']
+    if features_type != 'sound':
+        frontend_conf = None
+    if frontend_conf is not None:
+        if 'hop_length' in frontend_conf:
+            hop_length = frontend_conf['hop_length']
+
+    finish_count = 0
+    file_count = 1
+    if isinstance(audio_lists, bytes):
+        file_count = 1
+    else:
+        file_count = len(audio_lists)
+    if len(data_path_and_name_and_type) >= 3 and frontend_conf is not None:
+        mvn_file = data_path_and_name_and_type[2]
+        mvn_data = wav_utils.extract_CMVN_featrures(mvn_file)
+        frontend_conf['mvn_data'] = mvn_data
+    # 1. Set random-seed
+    set_all_random_seed(seed)
+
+    # 2. Build speech2text
+    speech2text_kwargs = dict(
+        asr_train_config=asr_train_config,
+        asr_model_file=asr_model_file,
+        lm_train_config=lm_train_config,
+        lm_file=lm_file,
+        token_type=token_type,
+        bpemodel=bpemodel,
+        device=device,
+        maxlenratio=maxlenratio,
+        minlenratio=minlenratio,
+        dtype=dtype,
+        beam_size=beam_size,
+        ctc_weight=ctc_weight,
+        lm_weight=lm_weight,
+        ngram_weight=ngram_weight,
+        penalty=penalty,
+        nbest=nbest,
+        streaming=streaming,
+        frontend_conf=frontend_conf,
+    )
+    speech2text = Speech2Text(**speech2text_kwargs)
+    data_path_and_name_and_type_new = [
+        audio_lists, data_path_and_name_and_type[0], data_path_and_name_and_type[1]
+    ]
+    # 3. Build data-iterator
+    loader = ASRTask.build_streaming_iterator_modelscope(
+        data_path_and_name_and_type_new,
+        dtype=dtype,
+        batch_size=batch_size,
+        key_file=key_file,
+        num_workers=num_workers,
+        preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
+        collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
+        allow_variable_data_keys=allow_variable_data_keys,
+        inference=True,
+        sample_rate=fs
+    )
+
+    # 7 .Start for-loop
+    # FIXME(kamo): The output format should be discussed about
+    asr_result_list = []
+    for keys, batch in loader:
+        assert isinstance(batch, dict), type(batch)
+        assert all(isinstance(s, str) for s in keys), keys
+        _bs = len(next(iter(batch.values())))
+        assert len(keys) == _bs, f"{len(keys)} != {_bs}"
+        batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
+
+        # N-best list of (text, token, token_int, hyp_object)
+        try:
+            results = speech2text(**batch)
+        except TooShortUttError as e:
+            logging.warning(f"Utterance {keys} {e}")
+            hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
+            results = [[" ", ["<space>"], [2], hyp]] * nbest
+
+        # Only supporting batch_size==1
+        key = keys[0]
+        for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
+            if text is not None:
+                text_postprocessed = postprocess_utils.sentence_postprocess(token)
+                item = {'key': key, 'value': text_postprocessed}
+                asr_result_list.append(item)
+        finish_count += 1
+        asr_utils.print_progress(finish_count / file_count)
+
+    return asr_result_list
+
+
+
+def set_parameters(language: str = None,
+                   sample_rate: Union[int, Dict[Any, int]] = None):
+    if language is not None:
+        global global_asr_language
+        global_asr_language = language
+    if sample_rate is not None:
+        global global_sample_rate
+        global_sample_rate = sample_rate
+
+
+def asr_inference(maxlenratio: float,
+                  minlenratio: float,
+                  beam_size: int,
+                  ngpu: int,
+                  ctc_weight: float,
+                  lm_weight: float,
+                  penalty: float,
+                  name_and_type: list,
+                  audio_lists: Union[List[Any], bytes],
+                  asr_train_config: Optional[str],
+                  asr_model_file: Optional[str],
+                  nbest: int = 1,
+                  num_workers: int = 1,
+                  log_level: Union[int, str] = 'INFO',
+                  batch_size: int = 1,
+                  dtype: str = 'float32',
+                  seed: int = 0,
+                  key_file: Optional[str] = None,
+                  lm_train_config: Optional[str] = None,
+                  lm_file: Optional[str] = None,
+                  word_lm_train_config: Optional[str] = None,
+                  word_lm_file: Optional[str] = None,
+                  ngram_file: Optional[str] = None,
+                  ngram_weight: float = 0.9,
+                  model_tag: Optional[str] = None,
+                  token_type: Optional[str] = None,
+                  bpemodel: Optional[str] = None,
+                  allow_variable_data_keys: bool = False,
+                  transducer_conf: Optional[dict] = None,
+                  streaming: bool = False,
+                  frontend_conf: dict = None,
+                  fs: Union[dict, int] = None,
+                  lang: Optional[str] = None,
+                  outputdir: Optional[str] = None):
+    if lang is not None:
+        global global_asr_language
+        global_asr_language = lang
+    if fs is not None:
+        global global_sample_rate
+        global_sample_rate = fs
+
+    # force use CPU if data type is bytes
+    if isinstance(audio_lists, bytes):
+        num_workers = 0
+        ngpu = 0
+
+    return inference(output_dir=outputdir,
+                     maxlenratio=maxlenratio,
+                     minlenratio=minlenratio,
+                     batch_size=batch_size,
+                     dtype=dtype,
+                     beam_size=beam_size,
+                     ngpu=ngpu,
+                     seed=seed,
+                     ctc_weight=ctc_weight,
+                     lm_weight=lm_weight,
+                     ngram_weight=ngram_weight,
+                     penalty=penalty,
+                     nbest=nbest,
+                     num_workers=num_workers,
+                     log_level=log_level,
+                     data_path_and_name_and_type=name_and_type,
+                     audio_lists=audio_lists,
+                     key_file=key_file,
+                     asr_train_config=asr_train_config,
+                     asr_model_file=asr_model_file,
+                     lm_train_config=lm_train_config,
+                     lm_file=lm_file,
+                     word_lm_train_config=word_lm_train_config,
+                     word_lm_file=word_lm_file,
+                     ngram_file=ngram_file,
+                     model_tag=model_tag,
+                     token_type=token_type,
+                     bpemodel=bpemodel,
+                     allow_variable_data_keys=allow_variable_data_keys,
+                     transducer_conf=transducer_conf,
+                     streaming=streaming,
+                     frontend_conf=frontend_conf)
+
+
+def get_parser():
+    parser = config_argparse.ArgumentParser(
+        description="ASR Decoding",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+
+    # Note(kamo): Use '_' instead of '-' as separator.
+    # '-' is confusing if written in yaml.
+    parser.add_argument(
+        "--log_level",
+        type=lambda x: x.upper(),
+        default="INFO",
+        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
+        help="The verbose level of logging",
+    )
+
+    parser.add_argument("--output_dir", type=str, required=True)
+    parser.add_argument(
+        "--ngpu",
+        type=int,
+        default=0,
+        help="The number of gpus. 0 indicates CPU mode",
+    )
+    parser.add_argument(
+        "--gpuid_list",
+        type=str,
+        default="",
+        help="The visible gpus",
+    )
+    parser.add_argument("--seed", type=int, default=0, help="Random seed")
+    parser.add_argument(
+        "--dtype",
+        default="float32",
+        choices=["float16", "float32", "float64"],
+        help="Data type",
+    )
+    parser.add_argument(
+        "--num_workers",
+        type=int,
+        default=1,
+        help="The number of workers used for DataLoader",
+    )
+
+    group = parser.add_argument_group("Input data related")
+    group.add_argument(
+        "--data_path_and_name_and_type",
+        type=str2triple_str,
+        required=True,
+        action="append",
+    )
+    group.add_argument("--audio_lists", type=list,
+                       default=[{'key':'EdevDEWdIYQ_0021',
+                                 'file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}])
+    group.add_argument("--key_file", type=str_or_none)
+    group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
+
+    group = parser.add_argument_group("The model configuration related")
+    group.add_argument(
+        "--asr_train_config",
+        type=str,
+        help="ASR training configuration",
+    )
+    group.add_argument(
+        "--asr_model_file",
+        type=str,
+        help="ASR model parameter file",
+    )
+    group.add_argument(
+        "--lm_train_config",
+        type=str,
+        help="LM training configuration",
+    )
+    group.add_argument(
+        "--lm_file",
+        type=str,
+        help="LM parameter file",
+    )
+    group.add_argument(
+        "--word_lm_train_config",
+        type=str,
+        help="Word LM training configuration",
+    )
+    group.add_argument(
+        "--word_lm_file",
+        type=str,
+        help="Word LM parameter file",
+    )
+    group.add_argument(
+        "--ngram_file",
+        type=str,
+        help="N-gram parameter file",
+    )
+    group.add_argument(
+        "--model_tag",
+        type=str,
+        help="Pretrained model tag. If specify this option, *_train_config and "
+             "*_file will be overwritten",
+    )
+
+    group = parser.add_argument_group("Beam-search related")
+    group.add_argument(
+        "--batch_size",
+        type=int,
+        default=1,
+        help="The batch size for inference",
+    )
+    group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
+    group.add_argument("--beam_size", type=int, default=20, help="Beam size")
+    group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
+    group.add_argument(
+        "--maxlenratio",
+        type=float,
+        default=0.0,
+        help="Input length ratio to obtain max output length. "
+             "If maxlenratio=0.0 (default), it uses a end-detect "
+             "function "
+             "to automatically find maximum hypothesis lengths."
+             "If maxlenratio<0.0, its absolute value is interpreted"
+             "as a constant max output length",
+    )
+    group.add_argument(
+        "--minlenratio",
+        type=float,
+        default=0.0,
+        help="Input length ratio to obtain min output length",
+    )
+    group.add_argument(
+        "--ctc_weight",
+        type=float,
+        default=0.5,
+        help="CTC weight in joint decoding",
+    )
+    group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
+    group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
+    group.add_argument("--streaming", type=str2bool, default=False)
+
+    group = parser.add_argument_group("Text converter related")
+    group.add_argument(
+        "--token_type",
+        type=str_or_none,
+        default=None,
+        choices=["char", "bpe", None],
+        help="The token type for ASR model. "
+             "If not given, refers from the training args",
+    )
+    group.add_argument(
+        "--bpemodel",
+        type=str_or_none,
+        default=None,
+        help="The model path of sentencepiece. "
+             "If not given, refers from the training args",
+    )
+
+    return parser
+
+
+def main(cmd=None):
+    print(get_commandline_args(), file=sys.stderr)
+    parser = get_parser()
+    args = parser.parse_args(cmd)
+    kwargs = vars(args)
+    kwargs.pop("config", None)
+    inference(**kwargs)
+
+
+if __name__ == "__main__":
+    main()

+ 686 - 0
funasr/bin/asr_inference_paraformer_modelscope.py

@@ -0,0 +1,686 @@
+#!/usr/bin/env python3
+import argparse
+import logging
+import sys
+import time
+from pathlib import Path
+from typing import Any
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Union
+from typing import List
+from typing import Dict
+
+import numpy as np
+import torch
+from typeguard import check_argument_types
+
+from funasr.modules.beam_search.beam_search import BeamSearchPara as BeamSearch
+from funasr.modules.beam_search.beam_search import Hypothesis
+from funasr.modules.scorers.ctc import CTCPrefixScorer
+from funasr.modules.scorers.length_bonus import LengthBonus
+from funasr.modules.subsampling import TooShortUttError
+from funasr.tasks.asr import ASRTaskParaformer as ASRTask
+from funasr.tasks.lm import LMTask
+from funasr.text.build_tokenizer import build_tokenizer
+from funasr.text.token_id_converter import TokenIDConverter
+from funasr.torch_utils.device_funcs import to_device
+from funasr.torch_utils.set_all_random_seed import set_all_random_seed
+from funasr.utils import config_argparse
+from funasr.utils.cli_utils import get_commandline_args
+from funasr.utils.types import str2bool
+from funasr.utils.types import str2triple_str
+from funasr.utils.types import str_or_none
+from funasr.utils import asr_utils, wav_utils, postprocess_utils
+from funasr.models.frontend.wav_frontend import WavFrontend
+
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+header_colors = '\033[95m'
+end_colors = '\033[0m'
+
+global_asr_language: str = 'zh-cn'
+global_sample_rate: Union[int, Dict[Any, int]] = {
+    'audio_fs': 16000,
+    'model_fs': 16000
+}
+
+
+class Speech2Text:
+    """Speech2Text class
+
+    Examples:
+            >>> import soundfile
+            >>> speech2text = Speech2Text("asr_config.yml", "asr.pth")
+            >>> audio, rate = soundfile.read("speech.wav")
+            >>> speech2text(audio)
+            [(text, token, token_int, hypothesis object), ...]
+
+    """
+
+    def __init__(
+            self,
+            asr_train_config: Union[Path, str] = None,
+            asr_model_file: Union[Path, str] = None,
+            lm_train_config: Union[Path, str] = None,
+            lm_file: Union[Path, str] = None,
+            token_type: str = None,
+            bpemodel: str = None,
+            device: str = "cpu",
+            maxlenratio: float = 0.0,
+            minlenratio: float = 0.0,
+            dtype: str = "float32",
+            beam_size: int = 20,
+            ctc_weight: float = 0.5,
+            lm_weight: float = 1.0,
+            ngram_weight: float = 0.9,
+            penalty: float = 0.0,
+            nbest: int = 1,
+            frontend_conf: dict = None,
+            **kwargs,
+    ):
+        assert check_argument_types()
+
+        # 1. Build ASR model
+        scorers = {}
+        asr_model, asr_train_args = ASRTask.build_model_from_file(
+            asr_train_config, asr_model_file, device
+        )
+        if asr_model.frontend is None and frontend_conf is not None:
+            frontend = WavFrontend(**frontend_conf)
+            asr_model.frontend = frontend
+        asr_model.to(dtype=getattr(torch, dtype)).eval()
+
+        ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
+        token_list = asr_model.token_list
+        scorers.update(
+            ctc=ctc,
+            length_bonus=LengthBonus(len(token_list)),
+        )
+
+        # 2. Build Language model
+        if lm_train_config is not None:
+            lm, lm_train_args = LMTask.build_model_from_file(
+                lm_train_config, lm_file, device
+            )
+            scorers["lm"] = lm.lm
+
+        # 3. Build ngram model
+        # ngram is not supported now
+        ngram = None
+        scorers["ngram"] = ngram
+
+        # 4. Build BeamSearch object
+        # transducer is not supported now
+        beam_search_transducer = None
+
+        weights = dict(
+            decoder=1.0 - ctc_weight,
+            ctc=ctc_weight,
+            lm=lm_weight,
+            ngram=ngram_weight,
+            length_bonus=penalty,
+        )
+        beam_search = BeamSearch(
+            beam_size=beam_size,
+            weights=weights,
+            scorers=scorers,
+            sos=asr_model.sos,
+            eos=asr_model.eos,
+            vocab_size=len(token_list),
+            token_list=token_list,
+            pre_beam_score_key=None if ctc_weight == 1.0 else "full",
+        )
+
+        beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
+        for scorer in scorers.values():
+            if isinstance(scorer, torch.nn.Module):
+                scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
+        logging.info(f"Beam_search: {beam_search}")
+        logging.info(f"Decoding device={device}, dtype={dtype}")
+
+        # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
+        if token_type is None:
+            token_type = asr_train_args.token_type
+        if bpemodel is None:
+            bpemodel = asr_train_args.bpemodel
+
+        if token_type is None:
+            tokenizer = None
+        elif token_type == "bpe":
+            if bpemodel is not None:
+                tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
+            else:
+                tokenizer = None
+        else:
+            tokenizer = build_tokenizer(token_type=token_type)
+        converter = TokenIDConverter(token_list=token_list)
+        logging.info(f"Text tokenizer: {tokenizer}")
+
+        self.asr_model = asr_model
+        self.asr_train_args = asr_train_args
+        self.converter = converter
+        self.tokenizer = tokenizer
+        self.beam_search = beam_search
+        self.beam_search_transducer = beam_search_transducer
+        self.maxlenratio = maxlenratio
+        self.minlenratio = minlenratio
+        self.device = device
+        self.dtype = dtype
+        self.nbest = nbest
+
+    @torch.no_grad()
+    def __call__(
+            self, speech: Union[torch.Tensor, np.ndarray]
+    ):
+        """Inference
+
+        Args:
+                speech: Input speech data
+        Returns:
+                text, token, token_int, hyp
+
+        """
+        assert check_argument_types()
+
+        # Input as audio signal
+        if isinstance(speech, np.ndarray):
+            speech = torch.tensor(speech)
+
+        # data: (Nsamples,) -> (1, Nsamples)
+        speech = speech.unsqueeze(0).to(getattr(torch, self.dtype))
+        lfr_factor = max(1, (speech.size()[-1] // 80) - 1)
+        # lengths: (1,)
+        lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
+        batch = {"speech": speech, "speech_lengths": lengths}
+
+        # a. To device
+        batch = to_device(batch, device=self.device)
+
+        # b. Forward Encoder
+        enc, enc_len = self.asr_model.encode(**batch)
+        if isinstance(enc, tuple):
+            enc = enc[0]
+        assert len(enc) == 1, len(enc)
+
+        predictor_outs = self.asr_model.calc_predictor(enc, enc_len)
+        pre_acoustic_embeds, pre_token_length = predictor_outs[0], predictor_outs[1]
+        pre_token_length = torch.tensor([pre_acoustic_embeds.size(1)], device=pre_acoustic_embeds.device)
+        decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length)
+        decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
+
+        nbest_hyps = self.beam_search(
+            x=enc[0], am_scores=decoder_out[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
+        )
+
+        nbest_hyps = nbest_hyps[: self.nbest]
+        results = []
+        for hyp in nbest_hyps:
+            assert isinstance(hyp, (Hypothesis)), type(hyp)
+
+            # remove sos/eos and get results
+            last_pos = -1
+            if isinstance(hyp.yseq, list):
+                token_int = hyp.yseq[1:last_pos]
+            else:
+                token_int = hyp.yseq[1:last_pos].tolist()
+
+            # remove blank symbol id, which is assumed to be 0
+            token_int = list(filter(lambda x: x != 0, token_int))
+
+            # Change integer-ids to tokens
+            token = self.converter.ids2tokens(token_int)
+
+            if self.tokenizer is not None:
+                text = self.tokenizer.tokens2text(token)
+            else:
+                text = None
+
+            results.append((text, token, token_int, hyp, speech.size(1), lfr_factor))
+
+        # assert check_return_type(results)
+        return results
+
+
+def inference(
+        maxlenratio: float,
+        minlenratio: float,
+        batch_size: int,
+        dtype: str,
+        beam_size: int,
+        ngpu: int,
+        seed: int,
+        ctc_weight: float,
+        lm_weight: float,
+        ngram_weight: float,
+        penalty: float,
+        nbest: int,
+        num_workers: int,
+        log_level: Union[int, str],
+        data_path_and_name_and_type: list,
+        audio_lists: Union[List[Any], bytes],
+        key_file: Optional[str],
+        asr_train_config: Optional[str],
+        asr_model_file: Optional[str],
+        lm_train_config: Optional[str],
+        lm_file: Optional[str],
+        word_lm_train_config: Optional[str],
+        model_tag: Optional[str],
+        token_type: Optional[str],
+        bpemodel: Optional[str],
+        output_dir: Optional[str],
+        allow_variable_data_keys: bool,
+        frontend_conf: dict = None,
+        fs: Union[dict, int] = 16000,
+        **kwargs,
+) -> List[Any]:
+    assert check_argument_types()
+    if batch_size > 1:
+        raise NotImplementedError("batch decoding is not implemented")
+    if word_lm_train_config is not None:
+        raise NotImplementedError("Word LM is not implemented")
+    if ngpu > 1:
+        raise NotImplementedError("only single GPU decoding is supported")
+
+    logging.basicConfig(
+        level=log_level,
+        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+    )
+
+    if ngpu >= 1:
+        device = "cuda"
+    else:
+        device = "cpu"
+    # data_path_and_name_and_type = data_path_and_name_and_type[0]
+    features_type: str = data_path_and_name_and_type[1]
+    hop_length: int = 160
+    sr: int = 16000
+    if isinstance(fs, int):
+        sr = fs
+    else:
+        if 'model_fs' in fs and fs['model_fs'] is not None:
+            sr = fs['model_fs']
+    if features_type != 'sound':
+        frontend_conf = None
+    if frontend_conf is not None:
+        if 'hop_length' in frontend_conf:
+            hop_length = frontend_conf['hop_length']
+
+    finish_count = 0
+    file_count = 1
+    if isinstance(audio_lists, bytes):
+        file_count = 1
+    else:
+        file_count = len(audio_lists)
+    if len(data_path_and_name_and_type) >= 3 and frontend_conf is not None:
+        mvn_file = data_path_and_name_and_type[2]
+        mvn_data = wav_utils.extract_CMVN_featrures(mvn_file)
+        frontend_conf['mvn_data'] = mvn_data
+
+    # 1. Set random-seed
+    set_all_random_seed(seed)
+
+    # 2. Build speech2text
+    speech2text_kwargs = dict(
+        asr_train_config=asr_train_config,
+        asr_model_file=asr_model_file,
+        lm_train_config=lm_train_config,
+        lm_file=lm_file,
+        token_type=token_type,
+        bpemodel=bpemodel,
+        device=device,
+        maxlenratio=maxlenratio,
+        minlenratio=minlenratio,
+        dtype=dtype,
+        beam_size=beam_size,
+        ctc_weight=ctc_weight,
+        lm_weight=lm_weight,
+        ngram_weight=ngram_weight,
+        penalty=penalty,
+        nbest=nbest,
+        frontend_conf=frontend_conf,
+    )
+    speech2text = Speech2Text(**speech2text_kwargs)
+
+    data_path_and_name_and_type_new = [
+        audio_lists, data_path_and_name_and_type[0], data_path_and_name_and_type[1]
+    ]
+
+    # 3. Build data-iterator
+    loader = ASRTask.build_streaming_iterator_modelscope(
+        data_path_and_name_and_type_new,
+        dtype=dtype,
+        batch_size=batch_size,
+        key_file=key_file,
+        num_workers=num_workers,
+        preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
+        collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
+        allow_variable_data_keys=allow_variable_data_keys,
+        inference=True,
+        sample_rate=fs
+    )
+
+    forward_time_total = 0.0
+    length_total = 0.0
+    asr_result_list = []
+    # 7 .Start for-loop
+    # FIXME(kamo): The output format should be discussed about
+    for keys, batch in loader:
+        assert isinstance(batch, dict), type(batch)
+        assert all(isinstance(s, str) for s in keys), keys
+        _bs = len(next(iter(batch.values())))
+        assert len(keys) == _bs, f"{len(keys)} != {_bs}"
+        batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
+
+        logging.info("decoding, utt_id: {}".format(keys))
+        # N-best list of (text, token, token_int, hyp_object)
+
+        try:
+            time_beg = time.time()
+            results = speech2text(**batch)
+            time_end = time.time()
+            forward_time = time_end - time_beg
+            lfr_factor = results[0][-1]
+            length = results[0][-2]
+            results = [results[0][:-2]]
+            forward_time_total += forward_time
+            length_total += length
+            logging.info(
+                "decoding, feature length: {}, forward_time: {:.4f}, rtf: {:.4f}".
+                    format(length, forward_time, 100 * forward_time / (length * lfr_factor)))
+        except TooShortUttError as e:
+            logging.warning(f"Utterance {keys} {e}")
+            hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
+            results = [[" ", ["<space>"], [2], hyp]] * nbest
+
+        # Only supporting batch_size==1
+        key = keys[0]
+        for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
+            if text is not None:
+                text_postprocessed = postprocess_utils.sentence_postprocess(token)
+                item = {'key': key, 'value': text_postprocessed}
+                asr_result_list.append(item)
+
+            logging.info("decoding, predictions: {}".format(text))
+        finish_count += 1
+        asr_utils.print_progress(finish_count / file_count)
+
+    logging.info("decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".
+                 format(length_total, forward_time_total, 100 * forward_time_total / (length_total * lfr_factor)))
+    if features_type == 'sound':
+        # data format is wav
+        length_total_seconds = length_total / sr
+        length_total_bytes = length_total * 2
+    else:
+        # data format is kaldi_ark
+        length_total_seconds = length_total * hop_length / sr
+        length_total_bytes = length_total * hop_length * 2
+
+    logger.info(
+        header_colors +  # noqa: *
+        'decoding, feature length total: {}bytes, forward_time total: {:.4f}s, rtf avg: {:.4f}'
+        .format(length_total_bytes, forward_time_total, forward_time_total /
+                length_total_seconds) + end_colors)
+
+    return asr_result_list
+
+
+def set_parameters(language: str = None,
+                   sample_rate: Union[int, Dict[Any, int]] = None):
+    if language is not None:
+        global global_asr_language
+        global_asr_language = language
+    if sample_rate is not None:
+        global global_sample_rate
+        global_sample_rate = sample_rate
+
+
+def asr_inference(maxlenratio: float,
+                  minlenratio: float,
+                  beam_size: int,
+                  ngpu: int,
+                  ctc_weight: float,
+                  lm_weight: float,
+                  penalty: float,
+                  name_and_type: list,
+                  audio_lists: Union[List[Any], bytes],
+                  asr_train_config: Optional[str],
+                  asr_model_file: Optional[str],
+                  nbest: int = 1,
+                  num_workers: int = 1,
+                  log_level: Union[int, str] = 'INFO',
+                  batch_size: int = 1,
+                  dtype: str = 'float32',
+                  seed: int = 0,
+                  key_file: Optional[str] = None,
+                  lm_train_config: Optional[str] = None,
+                  lm_file: Optional[str] = None,
+                  word_lm_train_config: Optional[str] = None,
+                  word_lm_file: Optional[str] = None,
+                  ngram_file: Optional[str] = None,
+                  ngram_weight: float = 0.9,
+                  model_tag: Optional[str] = None,
+                  token_type: Optional[str] = None,
+                  bpemodel: Optional[str] = None,
+                  allow_variable_data_keys: bool = False,
+                  transducer_conf: Optional[dict] = None,
+                  streaming: bool = False,
+                  frontend_conf: dict = None,
+                  fs: Union[dict, int] = None,
+                  lang: Optional[str] = None,
+                  outputdir: Optional[str] = None):
+    if lang is not None:
+        global global_asr_language
+        global_asr_language = lang
+    if fs is not None:
+        global global_sample_rate
+        global_sample_rate = fs
+
+    # force use CPU if data type is bytes
+    if isinstance(audio_lists, bytes):
+        num_workers = 0
+        ngpu = 0
+
+    return inference(output_dir=outputdir,
+                     maxlenratio=maxlenratio,
+                     minlenratio=minlenratio,
+                     batch_size=batch_size,
+                     dtype=dtype,
+                     beam_size=beam_size,
+                     ngpu=ngpu,
+                     seed=seed,
+                     ctc_weight=ctc_weight,
+                     lm_weight=lm_weight,
+                     ngram_weight=ngram_weight,
+                     penalty=penalty,
+                     nbest=nbest,
+                     num_workers=num_workers,
+                     log_level=log_level,
+                     data_path_and_name_and_type=name_and_type,
+                     audio_lists=audio_lists,
+                     key_file=key_file,
+                     asr_train_config=asr_train_config,
+                     asr_model_file=asr_model_file,
+                     lm_train_config=lm_train_config,
+                     lm_file=lm_file,
+                     word_lm_train_config=word_lm_train_config,
+                     word_lm_file=word_lm_file,
+                     ngram_file=ngram_file,
+                     model_tag=model_tag,
+                     token_type=token_type,
+                     bpemodel=bpemodel,
+                     allow_variable_data_keys=allow_variable_data_keys,
+                     transducer_conf=transducer_conf,
+                     streaming=streaming,
+                     frontend_conf=frontend_conf)
+
+
+
+def get_parser():
+    parser = config_argparse.ArgumentParser(
+        description="ASR Decoding",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+
+    # Note(kamo): Use '_' instead of '-' as separator.
+    # '-' is confusing if written in yaml.
+    parser.add_argument(
+        "--log_level",
+        type=lambda x: x.upper(),
+        default="INFO",
+        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
+        help="The verbose level of logging",
+    )
+
+    parser.add_argument("--output_dir", type=str, required=True)
+    parser.add_argument(
+        "--ngpu",
+        type=int,
+        default=0,
+        help="The number of gpus. 0 indicates CPU mode",
+    )
+    parser.add_argument("--seed", type=int, default=0, help="Random seed")
+    parser.add_argument(
+        "--dtype",
+        default="float32",
+        choices=["float16", "float32", "float64"],
+        help="Data type",
+    )
+    parser.add_argument(
+        "--num_workers",
+        type=int,
+        default=1,
+        help="The number of workers used for DataLoader",
+    )
+
+    group = parser.add_argument_group("Input data related")
+    group.add_argument(
+        "--data_path_and_name_and_type",
+        type=str2triple_str,
+        required=True,
+        action="append",
+    )
+    group.add_argument("--audio_lists", type=list, default=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}])
+    group.add_argument("--key_file", type=str_or_none)
+    group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
+
+    group = parser.add_argument_group("The model configuration related")
+    group.add_argument(
+        "--asr_train_config",
+        type=str,
+        help="ASR training configuration",
+    )
+    group.add_argument(
+        "--asr_model_file",
+        type=str,
+        help="ASR model parameter file",
+    )
+    group.add_argument(
+        "--lm_train_config",
+        type=str,
+        help="LM training configuration",
+    )
+    group.add_argument(
+        "--lm_file",
+        type=str,
+        help="LM parameter file",
+    )
+    group.add_argument(
+        "--word_lm_train_config",
+        type=str,
+        help="Word LM training configuration",
+    )
+    group.add_argument(
+        "--word_lm_file",
+        type=str,
+        help="Word LM parameter file",
+    )
+    group.add_argument(
+        "--ngram_file",
+        type=str,
+        help="N-gram parameter file",
+    )
+    group.add_argument(
+        "--model_tag",
+        type=str,
+        help="Pretrained model tag. If specify this option, *_train_config and "
+             "*_file will be overwritten",
+    )
+
+    group = parser.add_argument_group("Beam-search related")
+    group.add_argument(
+        "--batch_size",
+        type=int,
+        default=1,
+        help="The batch size for inference",
+    )
+    group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
+    group.add_argument("--beam_size", type=int, default=20, help="Beam size")
+    group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
+    group.add_argument(
+        "--maxlenratio",
+        type=float,
+        default=0.0,
+        help="Input length ratio to obtain max output length. "
+             "If maxlenratio=0.0 (default), it uses a end-detect "
+             "function "
+             "to automatically find maximum hypothesis lengths."
+             "If maxlenratio<0.0, its absolute value is interpreted"
+             "as a constant max output length",
+    )
+    group.add_argument(
+        "--minlenratio",
+        type=float,
+        default=0.0,
+        help="Input length ratio to obtain min output length",
+    )
+    group.add_argument(
+        "--ctc_weight",
+        type=float,
+        default=0.5,
+        help="CTC weight in joint decoding",
+    )
+    group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
+    group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
+    group.add_argument("--streaming", type=str2bool, default=False)
+
+    group.add_argument(
+        "--asr_model_config",
+        default=None,
+        help="",
+    )
+
+    group = parser.add_argument_group("Text converter related")
+    group.add_argument(
+        "--token_type",
+        type=str_or_none,
+        default=None,
+        choices=["char", "bpe", None],
+        help="The token type for ASR model. "
+             "If not given, refers from the training args",
+    )
+    group.add_argument(
+        "--bpemodel",
+        type=str_or_none,
+        default=None,
+        help="The model path of sentencepiece. "
+             "If not given, refers from the training args",
+    )
+
+    return parser
+
+
+def main(cmd=None):
+    print(get_commandline_args(), file=sys.stderr)
+    parser = get_parser()
+    args = parser.parse_args(cmd)
+    kwargs = vars(args)
+    kwargs.pop("config", None)
+    inference(**kwargs)
+
+
+if __name__ == "__main__":
+    main()

+ 6 - 1
funasr/bin/modelscope_infer.py

@@ -15,6 +15,10 @@ if __name__ == '__main__':
                         type=str,
                         default="speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                         help="model name in modelscope")
+    parser.add_argument("--model_revision",
+                        type=str,
+                        default="v1.0.3",
+                        help="model revision in modelscope")
     parser.add_argument("--local_model_path",
                         type=str,
                         default=None,
@@ -62,7 +66,8 @@ if __name__ == '__main__':
     if args.local_model_path is None:
         inference_pipeline = pipeline(
             task=Tasks.auto_speech_recognition,
-            model="damo/{}".format(args.model_name))
+            model="damo/{}".format(args.model_name),
+            model_revision=args.model_revision)
     else:
         inference_pipeline = pipeline(
             task=Tasks.auto_speech_recognition,

+ 349 - 0
funasr/datasets/iterable_dataset_modelscope.py

@@ -0,0 +1,349 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+# Part of the implementation is borrowed from espnet/espnet.
+"""Iterable dataset module."""
+import copy
+from io import StringIO
+from pathlib import Path
+from typing import Callable, Collection, Dict, Iterator, Tuple, Union
+
+import kaldiio
+import numpy as np
+import soundfile
+import torch
+from funasr.datasets.dataset import ESPnetDataset
+from torch.utils.data.dataset import IterableDataset
+from typeguard import check_argument_types
+
+from funasr.utils import wav_utils
+
+
+def load_kaldi(input):
+    retval = kaldiio.load_mat(input)
+    if isinstance(retval, tuple):
+        assert len(retval) == 2, len(retval)
+        if isinstance(retval[0], int) and isinstance(retval[1], np.ndarray):
+            # sound scp case
+            rate, array = retval
+        elif isinstance(retval[1], int) and isinstance(retval[0], np.ndarray):
+            # Extended ark format case
+            array, rate = retval
+        else:
+            raise RuntimeError(
+                f'Unexpected type: {type(retval[0])}, {type(retval[1])}')
+
+        # Multichannel wave fie
+        # array: (NSample, Channel) or (Nsample)
+
+    else:
+        # Normal ark case
+        assert isinstance(retval, np.ndarray), type(retval)
+        array = retval
+    return array
+
+
+DATA_TYPES = {
+    'sound':
+    lambda x: soundfile.read(x)[0],
+    'kaldi_ark':
+    load_kaldi,
+    'npy':
+    np.load,
+    'text_int':
+    lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.long, delimiter=' '),
+    'csv_int':
+    lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.long, delimiter=','),
+    'text_float':
+    lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.float32, delimiter=' '
+                         ),
+    'csv_float':
+    lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.float32, delimiter=','
+                         ),
+    'text':
+    lambda x: x,
+}
+
+
+class IterableESPnetDatasetModelScope(IterableDataset):
+    """Pytorch Dataset class for ESPNet.
+
+    Examples:
+        >>> dataset = IterableESPnetDataset([('wav.scp', 'input', 'sound'),
+        ...                                  ('token_int', 'output', 'text_int')],
+        ...                                )
+        >>> for uid, data in dataset:
+        ...     data
+        {'input': per_utt_array, 'output': per_utt_array}
+    """
+    def __init__(self,
+                 path_name_type_list: Collection[Tuple[any, str, str]],
+                 preprocess: Callable[[str, Dict[str, np.ndarray]],
+                                      Dict[str, np.ndarray]] = None,
+                 float_dtype: str = 'float32',
+                 int_dtype: str = 'long',
+                 key_file: str = None,
+                 sample_rate: Union[dict, int] = 16000):
+        assert check_argument_types()
+        if len(path_name_type_list) == 0:
+            raise ValueError(
+                '1 or more elements are required for "path_name_type_list"')
+
+        self.preprocess = preprocess
+
+        self.float_dtype = float_dtype
+        self.int_dtype = int_dtype
+        self.key_file = key_file
+        self.sample_rate = sample_rate
+
+        self.debug_info = {}
+        non_iterable_list = []
+        self.path_name_type_list = []
+
+        path_list = path_name_type_list[0]
+        name = path_name_type_list[1]
+        _type = path_name_type_list[2]
+        if name in self.debug_info:
+            raise RuntimeError(f'"{name}" is duplicated for data-key')
+        self.debug_info[name] = path_list, _type
+        #        for path, name, _type in path_name_type_list:
+        for path in path_list:
+            self.path_name_type_list.append((path, name, _type))
+
+        if len(non_iterable_list) != 0:
+            # Some types doesn't support iterable mode
+            self.non_iterable_dataset = ESPnetDataset(
+                path_name_type_list=non_iterable_list,
+                preprocess=preprocess,
+                float_dtype=float_dtype,
+                int_dtype=int_dtype,
+            )
+        else:
+            self.non_iterable_dataset = None
+
+        self.apply_utt2category = False
+
+    def has_name(self, name) -> bool:
+        return name in self.debug_info
+
+    def names(self) -> Tuple[str, ...]:
+        return tuple(self.debug_info)
+
+    def __repr__(self):
+        _mes = self.__class__.__name__
+        _mes += '('
+        for name, (path, _type) in self.debug_info.items():
+            _mes += f'\n  {name}: {{"path": "{path}", "type": "{_type}"}}'
+        _mes += f'\n  preprocess: {self.preprocess})'
+        return _mes
+
+    def __iter__(
+            self) -> Iterator[Tuple[Union[str, int], Dict[str, np.ndarray]]]:
+        torch.set_printoptions(profile='default')
+        count = len(self.path_name_type_list)
+        for idx in range(count):
+            # 2. Load the entry from each line and create a dict
+            data = {}
+            # 2.a. Load data streamingly
+
+            # value:  /home/fsc/code/MaaS/MaaS-lib-nls-asr/data/test/audios/asr_example.wav
+            value = self.path_name_type_list[idx][0]['file']
+            uid = self.path_name_type_list[idx][0]['key']
+            # name:  speech
+            name = self.path_name_type_list[idx][1]
+            _type = self.path_name_type_list[idx][2]
+            func = DATA_TYPES[_type]
+            array = func(value)
+
+            # 2.b. audio resample
+            if _type == 'sound':
+                audio_sr: int = 16000
+                model_sr: int = 16000
+                if isinstance(self.sample_rate, int):
+                    model_sr = self.sample_rate
+                else:
+                    if 'audio_sr' in self.sample_rate:
+                        audio_sr = self.sample_rate['audio_sr']
+                    if 'model_sr' in self.sample_rate:
+                        model_sr = self.sample_rate['model_sr']
+                array = wav_utils.torch_resample(array, audio_sr, model_sr)
+
+            # array:  [ 1.25122070e-03  ... ]
+            data[name] = array
+
+            # 3. [Option] Apply preprocessing
+            #   e.g. espnet2.train.preprocessor:CommonPreprocessor
+            if self.preprocess is not None:
+                data = self.preprocess(uid, data)
+                # data:  {'speech': array([ 1.25122070e-03 ... 6.10351562e-03])}
+
+            # 4. Force data-precision
+            for name in data:
+                # value is np.ndarray data
+                value = data[name]
+                if not isinstance(value, np.ndarray):
+                    raise RuntimeError(
+                        f'All values must be converted to np.ndarray object '
+                        f'by preprocessing, but "{name}" is still {type(value)}.'
+                    )
+
+                # Cast to desired type
+                if value.dtype.kind == 'f':
+                    value = value.astype(self.float_dtype)
+                elif value.dtype.kind == 'i':
+                    value = value.astype(self.int_dtype)
+                else:
+                    raise NotImplementedError(
+                        f'Not supported dtype: {value.dtype}')
+                data[name] = value
+
+            yield uid, data
+
+        if count == 0:
+            raise RuntimeError('No iteration')
+
+
+class IterableESPnetBytesModelScope(IterableDataset):
+    """Pytorch audio bytes class for ESPNet.
+
+    Examples:
+        >>> dataset = IterableESPnetBytes([('audio bytes', 'input', 'sound'),
+        ...                                ('token_int', 'output', 'text_int')],
+        ...                                )
+        >>> for uid, data in dataset:
+        ...     data
+        {'input': per_utt_array, 'output': per_utt_array}
+    """
+    def __init__(self,
+                 path_name_type_list: Collection[Tuple[any, str, str]],
+                 preprocess: Callable[[str, Dict[str, np.ndarray]],
+                                      Dict[str, np.ndarray]] = None,
+                 float_dtype: str = 'float32',
+                 int_dtype: str = 'long',
+                 key_file: str = None,
+                 sample_rate: Union[dict, int] = 16000):
+        assert check_argument_types()
+        if len(path_name_type_list) == 0:
+            raise ValueError(
+                '1 or more elements are required for "path_name_type_list"')
+
+        self.preprocess = preprocess
+
+        self.float_dtype = float_dtype
+        self.int_dtype = int_dtype
+        self.key_file = key_file
+        self.sample_rate = sample_rate
+
+        self.debug_info = {}
+        non_iterable_list = []
+        self.path_name_type_list = []
+
+        audio_data = path_name_type_list[0]
+        name = path_name_type_list[1]
+        _type = path_name_type_list[2]
+        if name in self.debug_info:
+            raise RuntimeError(f'"{name}" is duplicated for data-key')
+        self.debug_info[name] = audio_data, _type
+        self.path_name_type_list.append((audio_data, name, _type))
+
+        if len(non_iterable_list) != 0:
+            # Some types doesn't support iterable mode
+            self.non_iterable_dataset = ESPnetDataset(
+                path_name_type_list=non_iterable_list,
+                preprocess=preprocess,
+                float_dtype=float_dtype,
+                int_dtype=int_dtype,
+            )
+        else:
+            self.non_iterable_dataset = None
+
+        self.apply_utt2category = False
+
+        if float_dtype == 'float32':
+            self.np_dtype = np.float32
+
+    def has_name(self, name) -> bool:
+        return name in self.debug_info
+
+    def names(self) -> Tuple[str, ...]:
+        return tuple(self.debug_info)
+
+    def __repr__(self):
+        _mes = self.__class__.__name__
+        _mes += '('
+        for name, (path, _type) in self.debug_info.items():
+            _mes += f'\n  {name}: {{"path": "{path}", "type": "{_type}"}}'
+        _mes += f'\n  preprocess: {self.preprocess})'
+        return _mes
+
+    def __iter__(
+            self) -> Iterator[Tuple[Union[str, int], Dict[str, np.ndarray]]]:
+
+        torch.set_printoptions(profile='default')
+        # 2. Load the entry from each line and create a dict
+        data = {}
+        # 2.a. Load data streamingly
+
+        value = self.path_name_type_list[0][0]
+        uid = 'pcm_data'
+        # name:  speech
+        name = self.path_name_type_list[0][1]
+        _type = self.path_name_type_list[0][2]
+        func = DATA_TYPES[_type]
+        # array:  [ 1.25122070e-03  ... ]
+        #        data[name] = np.frombuffer(value, dtype=self.np_dtype)
+
+        # 2.b. byte(PCM16) to float32
+        middle_data = np.frombuffer(value, dtype=np.int16)
+        middle_data = np.asarray(middle_data)
+        if middle_data.dtype.kind not in 'iu':
+            raise TypeError("'middle_data' must be an array of integers")
+        dtype = np.dtype('float32')
+        if dtype.kind != 'f':
+            raise TypeError("'dtype' must be a floating point type")
+
+        i = np.iinfo(middle_data.dtype)
+        abs_max = 2**(i.bits - 1)
+        offset = i.min + abs_max
+        array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max,
+                              dtype=self.np_dtype)
+
+        # 2.c. audio resample
+        if _type == 'sound':
+            audio_sr: int = 16000
+            model_sr: int = 16000
+            if isinstance(self.sample_rate, int):
+                model_sr = self.sample_rate
+            else:
+                if 'audio_sr' in self.sample_rate:
+                    audio_sr = self.sample_rate['audio_sr']
+                if 'model_sr' in self.sample_rate:
+                    model_sr = self.sample_rate['model_sr']
+            array = wav_utils.torch_resample(array, audio_sr, model_sr)
+
+        data[name] = array
+
+        # 3. [Option] Apply preprocessing
+        #   e.g. espnet2.train.preprocessor:CommonPreprocessor
+        if self.preprocess is not None:
+            data = self.preprocess(uid, data)
+            # data:  {'speech': array([ 1.25122070e-03 ... 6.10351562e-03])}
+
+        # 4. Force data-precision
+        for name in data:
+            # value is np.ndarray data
+            value = data[name]
+            if not isinstance(value, np.ndarray):
+                raise RuntimeError(
+                    f'All values must be converted to np.ndarray object '
+                    f'by preprocessing, but "{name}" is still {type(value)}.')
+
+            # Cast to desired type
+            if value.dtype.kind == 'f':
+                value = value.astype(self.float_dtype)
+            elif value.dtype.kind == 'i':
+                value = value.astype(self.int_dtype)
+            else:
+                raise NotImplementedError(
+                    f'Not supported dtype: {value.dtype}')
+            data[name] = value
+
+        yield uid, data

+ 2 - 3
funasr/models/e2e_asr_paraformer.py

@@ -330,9 +330,10 @@ class Paraformer(AbsESPnetModel):
 
 	def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
 
-		decoder_out, _ = self.decoder(
+		decoder_outs = self.decoder(
 			encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
 		)
+		decoder_out = decoder_outs[0]
 		decoder_out = torch.log_softmax(decoder_out, dim=-1)
 		return decoder_out, ys_pad_lens
 
@@ -553,7 +554,6 @@ class ParaformerBert(Paraformer):
 		postencoder: Optional[AbsPostEncoder],
 		decoder: AbsDecoder,
 		ctc: CTC,
-		joint_network: Optional[torch.nn.Module],
 		ctc_weight: float = 0.5,
 		interctc_weight: float = 0.0,
 		ignore_id: int = -1,
@@ -590,7 +590,6 @@ class ParaformerBert(Paraformer):
 		postencoder=postencoder,
 		decoder=decoder,
 		ctc=ctc,
-		joint_network=joint_network,
 		ctc_weight=ctc_weight,
 		interctc_weight=interctc_weight,
 		ignore_id=ignore_id,

+ 155 - 0
funasr/models/frontend/wav_frontend.py

@@ -0,0 +1,155 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+# Part of the implementation is borrowed from espnet/espnet.
+
+import copy
+from typing import Optional, Tuple, Union
+
+import humanfriendly
+import numpy as np
+import torch
+import torchaudio.compliance.kaldi as kaldi
+from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.layers.log_mel import LogMel
+from funasr.layers.stft import Stft
+from funasr.utils.get_default_kwargs import get_default_kwargs
+from funasr.modules.frontends.frontend import Frontend
+from typeguard import check_argument_types
+
+
+def apply_cmvn(inputs, mvn):  # noqa
+    """
+    Apply CMVN with mvn data
+    """
+
+    device = inputs.device
+    dtype = inputs.dtype
+    frame, dim = inputs.shape
+
+    meams = np.tile(mvn[0:1, :dim], (frame, 1))
+    vars = np.tile(mvn[1:2, :dim], (frame, 1))
+    inputs += torch.from_numpy(meams).type(dtype).to(device)
+    inputs *= torch.from_numpy(vars).type(dtype).to(device)
+
+    return inputs.type(torch.float32)
+
+
+def apply_lfr(inputs, lfr_m, lfr_n):
+    LFR_inputs = []
+    T = inputs.shape[0]
+    T_lfr = int(np.ceil(T / lfr_n))
+    left_padding = inputs[0].repeat((lfr_m - 1) // 2, 1)
+    inputs = torch.vstack((left_padding, inputs))
+    T = T + (lfr_m - 1) // 2
+    for i in range(T_lfr):
+        if lfr_m <= T - i * lfr_n:
+            LFR_inputs.append((inputs[i * lfr_n:i * lfr_n + lfr_m]).view(1, -1))
+        else:  # process last LFR frame
+            num_padding = lfr_m - (T - i * lfr_n)
+            frame = (inputs[i * lfr_n:]).view(-1)
+            for _ in range(num_padding):
+                frame = torch.hstack((frame, inputs[-1]))
+            LFR_inputs.append(frame)
+    LFR_outputs = torch.vstack(LFR_inputs)
+    return LFR_outputs.type(torch.float32)
+
+
+class WavFrontend(AbsFrontend):
+    """Conventional frontend structure for ASR.
+    """
+    def __init__(
+        self,
+        fs: Union[int, str] = 16000,
+        n_fft: int = 512,
+        win_length: int = 400,
+        hop_length: int = 160,
+        window: Optional[str] = 'hamming',
+        center: bool = True,
+        normalized: bool = False,
+        onesided: bool = True,
+        n_mels: int = 80,
+        fmin: int = None,
+        fmax: int = None,
+        lfr_m: int = 1,
+        lfr_n: int = 1,
+        htk: bool = False,
+        mvn_data=None,
+        frontend_conf: Optional[dict] = get_default_kwargs(Frontend),
+        apply_stft: bool = True,
+    ):
+        assert check_argument_types()
+        super().__init__()
+        if isinstance(fs, str):
+            fs = humanfriendly.parse_size(fs)
+
+        # Deepcopy (In general, dict shouldn't be used as default arg)
+        frontend_conf = copy.deepcopy(frontend_conf)
+        self.hop_length = hop_length
+        self.win_length = win_length
+        self.window = window
+        self.fs = fs
+        self.mvn_data = mvn_data
+        self.lfr_m = lfr_m
+        self.lfr_n = lfr_n
+
+        if apply_stft:
+            self.stft = Stft(
+                n_fft=n_fft,
+                win_length=win_length,
+                hop_length=hop_length,
+                center=center,
+                window=window,
+                normalized=normalized,
+                onesided=onesided,
+            )
+        else:
+            self.stft = None
+        self.apply_stft = apply_stft
+
+        if frontend_conf is not None:
+            self.frontend = Frontend(idim=n_fft // 2 + 1, **frontend_conf)
+        else:
+            self.frontend = None
+
+        self.logmel = LogMel(
+            fs=fs,
+            n_fft=n_fft,
+            n_mels=n_mels,
+            fmin=fmin,
+            fmax=fmax,
+            htk=htk,
+        )
+        self.n_mels = n_mels
+        self.frontend_type = 'default'
+
+    def output_size(self) -> int:
+        return self.n_mels
+
+    def forward(
+            self, input: torch.Tensor,
+            input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+
+        sample_frequency = self.fs
+        num_mel_bins = self.n_mels
+        frame_length = self.win_length * 1000 / sample_frequency
+        frame_shift = self.hop_length * 1000 / sample_frequency
+
+        waveform = input * (1 << 15)
+
+        mat = kaldi.fbank(waveform,
+                          num_mel_bins=num_mel_bins,
+                          frame_length=frame_length,
+                          frame_shift=frame_shift,
+                          dither=1.0,
+                          energy_floor=0.0,
+                          window_type=self.window,
+                          sample_frequency=sample_frequency)
+        if self.lfr_m != 1 or self.lfr_n != 1:
+            mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
+        if self.mvn_data is not None:
+            mat = apply_cmvn(mat, self.mvn_data)
+
+        input_feats = mat[None, :]
+        feats_lens = torch.randn(1)
+        feats_lens.fill_(input_feats.shape[1])
+
+        return input_feats, feats_lens

+ 1 - 1
funasr/models/predictor/cif.py

@@ -4,7 +4,7 @@ from torch import nn
 from funasr.modules.nets_utils import make_pad_mask
 
 class CifPredictor(nn.Module):
-    def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1, smooth_factor=1.0, noise_threshold=0):
+    def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1, smooth_factor=1.0, noise_threshold=0, tail_threshold=0.45):
         super(CifPredictor, self).__init__()
 
         self.pad = nn.ConstantPad1d((l_order, r_order), 0)

+ 60 - 1
funasr/tasks/abs_task.py

@@ -38,6 +38,7 @@ from funasr.datasets.dataset import AbsDataset
 from funasr.datasets.dataset import DATA_TYPES
 from funasr.datasets.dataset import ESPnetDataset
 from funasr.datasets.iterable_dataset import IterableESPnetDataset
+from funasr.datasets.iterable_dataset_modelscope import IterableESPnetDatasetModelScope, IterableESPnetBytesModelScope
 from funasr.iterators.abs_iter_factory import AbsIterFactory
 from funasr.iterators.chunk_iter_factory import ChunkIterFactory
 from funasr.iterators.multiple_iter_factory import MultipleIterFactory
@@ -1026,7 +1027,7 @@ class AbsTask(ABC):
     @classmethod
     def check_task_requirements(
             cls,
-            dataset: Union[AbsDataset, IterableESPnetDataset],
+            dataset: Union[AbsDataset, IterableESPnetDataset, IterableESPnetDatasetModelScope, IterableESPnetBytesModelScope],
             allow_variable_data_keys: bool,
             train: bool,
             inference: bool = False,
@@ -1748,6 +1749,64 @@ class AbsTask(ABC):
             **kwargs,
         )
 
+    @classmethod
+    def build_streaming_iterator_modelscope(
+            cls,
+            data_path_and_name_and_type,
+            preprocess_fn,
+            collate_fn,
+            key_file: str = None,
+            batch_size: int = 1,
+            dtype: str = np.float32,
+            num_workers: int = 1,
+            allow_variable_data_keys: bool = False,
+            ngpu: int = 0,
+            inference: bool = False,
+            sample_rate: Union[dict, int] = 16000
+    ) -> DataLoader:
+        """Build DataLoader using iterable dataset"""
+        assert check_argument_types()
+        # For backward compatibility for pytorch DataLoader
+        if collate_fn is not None:
+            kwargs = dict(collate_fn=collate_fn)
+        else:
+            kwargs = {}
+
+        audio_data = data_path_and_name_and_type[0]
+        if isinstance(audio_data, bytes):
+            dataset = IterableESPnetBytesModelScope(
+                data_path_and_name_and_type,
+                float_dtype=dtype,
+                preprocess=preprocess_fn,
+                key_file=key_file,
+                sample_rate=sample_rate
+            )
+        else:
+            dataset = IterableESPnetDatasetModelScope(
+                data_path_and_name_and_type,
+                float_dtype=dtype,
+                preprocess=preprocess_fn,
+                key_file=key_file,
+                sample_rate=sample_rate
+            )
+
+        if dataset.apply_utt2category:
+            kwargs.update(batch_size=1)
+        else:
+            kwargs.update(batch_size=batch_size)
+
+        cls.check_task_requirements(dataset,
+                                    allow_variable_data_keys,
+                                    train=False,
+                                    inference=inference)
+
+        return DataLoader(
+            dataset=dataset,
+            pin_memory=ngpu > 0,
+            num_workers=num_workers,
+            **kwargs,
+        )
+
     # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
     @classmethod
     def build_model_from_file(

+ 85 - 0
funasr/utils/asr_env_checking.py

@@ -0,0 +1,85 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import os
+import shutil
+import ssl
+
+import nltk
+
+# mkdir nltk_data dir if not exist
+try:
+    nltk.data.find('.')
+except LookupError:
+    dir_list = nltk.data.path
+    for dir_item in dir_list:
+        if not os.path.exists(dir_item):
+            os.mkdir(dir_item)
+        if os.path.exists(dir_item):
+            break
+
+# download one package if nltk_data not exist
+try:
+    nltk.data.find('.')
+except:  # noqa: *
+    try:
+        _create_unverified_https_context = ssl._create_unverified_context
+    except AttributeError:
+        pass
+    else:
+        ssl._create_default_https_context = _create_unverified_https_context
+
+    nltk.download('cmudict', halt_on_error=False, raise_on_error=True)
+
+# deploy taggers/averaged_perceptron_tagger
+try:
+    nltk.data.find('taggers/averaged_perceptron_tagger')
+except:  # noqa: *
+    data_dir = nltk.data.find('.')
+    target_dir = os.path.join(data_dir, 'taggers')
+    if not os.path.exists(target_dir):
+        os.mkdir(target_dir)
+    src_file = os.path.join(os.path.dirname(__file__), '..', 'nltk_packages',
+                            'averaged_perceptron_tagger.zip')
+    shutil.copyfile(src_file,
+                    os.path.join(target_dir, 'averaged_perceptron_tagger.zip'))
+    shutil._unpack_zipfile(
+        os.path.join(target_dir, 'averaged_perceptron_tagger.zip'), target_dir)
+
+# deploy corpora/cmudict
+try:
+    nltk.data.find('corpora/cmudict')
+except:  # noqa: *
+    data_dir = nltk.data.find('.')
+    target_dir = os.path.join(data_dir, 'corpora')
+    if not os.path.exists(target_dir):
+        os.mkdir(target_dir)
+    src_file = os.path.join(os.path.dirname(__file__), '..', 'nltk_packages',
+                            'cmudict.zip')
+    shutil.copyfile(src_file, os.path.join(target_dir, 'cmudict.zip'))
+    shutil._unpack_zipfile(os.path.join(target_dir, 'cmudict.zip'), target_dir)
+
+try:
+    nltk.data.find('taggers/averaged_perceptron_tagger')
+except:  # noqa: *
+    try:
+        _create_unverified_https_context = ssl._create_unverified_context
+    except AttributeError:
+        pass
+    else:
+        ssl._create_default_https_context = _create_unverified_https_context
+
+    nltk.download('averaged_perceptron_tagger',
+                  halt_on_error=False,
+                  raise_on_error=True)
+
+try:
+    nltk.data.find('corpora/cmudict')
+except:  # noqa: *
+    try:
+        _create_unverified_https_context = ssl._create_unverified_context
+    except AttributeError:
+        pass
+    else:
+        ssl._create_default_https_context = _create_unverified_https_context
+
+    nltk.download('cmudict', halt_on_error=False, raise_on_error=True)

+ 327 - 0
funasr/utils/asr_utils.py

@@ -0,0 +1,327 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import os
+import struct
+from typing import Any, Dict, List, Union
+
+import librosa
+import numpy as np
+import pkg_resources
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+green_color = '\033[1;32m'
+red_color = '\033[0;31;40m'
+yellow_color = '\033[0;33;40m'
+end_color = '\033[0m'
+
+global_asr_language = 'zh-cn'
+
+
+def get_version():
+    return float(pkg_resources.get_distribution('easyasr').version)
+
+
+def sample_rate_checking(audio_in: Union[str, bytes], audio_format: str):
+    r_audio_fs = None
+
+    if audio_format == 'wav':
+        r_audio_fs = get_sr_from_wav(audio_in)
+    elif audio_format == 'pcm' and isinstance(audio_in, bytes):
+        r_audio_fs = get_sr_from_bytes(audio_in)
+
+    return r_audio_fs
+
+
+def type_checking(audio_in: Union[str, bytes],
+                  audio_fs: int = None,
+                  recog_type: str = None,
+                  audio_format: str = None):
+    r_recog_type = recog_type
+    r_audio_format = audio_format
+    r_wav_path = audio_in
+
+    if isinstance(audio_in, str):
+        assert os.path.exists(audio_in), f'wav_path:{audio_in} does not exist'
+    elif isinstance(audio_in, bytes):
+        assert len(audio_in) > 0, 'audio in is empty'
+        r_audio_format = 'pcm'
+        r_recog_type = 'wav'
+
+    if r_recog_type is None:
+        # audio_in is wav, recog_type is wav_file
+        if os.path.isfile(audio_in):
+            if audio_in.endswith('.wav') or audio_in.endswith('.WAV'):
+                r_recog_type = 'wav'
+                r_audio_format = 'wav'
+
+        # recog_type is datasets_file
+        elif os.path.isdir(audio_in):
+            dir_name = os.path.basename(audio_in)
+            if 'test' in dir_name:
+                r_recog_type = 'test'
+            elif 'dev' in dir_name:
+                r_recog_type = 'dev'
+            elif 'train' in dir_name:
+                r_recog_type = 'train'
+
+    if r_audio_format is None:
+        if find_file_by_ends(audio_in, '.ark'):
+            r_audio_format = 'kaldi_ark'
+        elif find_file_by_ends(audio_in, '.wav') or find_file_by_ends(
+                audio_in, '.WAV'):
+            r_audio_format = 'wav'
+        elif find_file_by_ends(audio_in, '.records'):
+            r_audio_format = 'tfrecord'
+
+    if r_audio_format == 'kaldi_ark' and r_recog_type != 'wav':
+        # datasets with kaldi_ark file
+        r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../'))
+    elif r_audio_format == 'tfrecord' and r_recog_type != 'wav':
+        # datasets with tensorflow records file
+        r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../'))
+    elif r_audio_format == 'wav' and r_recog_type != 'wav':
+        # datasets with waveform files
+        r_wav_path = os.path.abspath(os.path.join(r_wav_path, '../../'))
+
+    return r_recog_type, r_audio_format, r_wav_path
+
+
+def get_sr_from_bytes(wav: bytes):
+    sr = None
+    data = wav
+    if len(data) > 44:
+        try:
+            header_fields = {}
+            header_fields['ChunkID'] = str(data[0:4], 'UTF-8')
+            header_fields['Format'] = str(data[8:12], 'UTF-8')
+            header_fields['Subchunk1ID'] = str(data[12:16], 'UTF-8')
+            if header_fields['ChunkID'] == 'RIFF' and header_fields[
+                    'Format'] == 'WAVE' and header_fields[
+                        'Subchunk1ID'] == 'fmt ':
+                header_fields['SampleRate'] = struct.unpack('<I',
+                                                            data[24:28])[0]
+                sr = header_fields['SampleRate']
+        except Exception:
+            # no treatment
+            pass
+    else:
+        logger.warn('audio bytes is ' + str(len(data)) + ' is invalid.')
+
+    return sr
+
+
+def get_sr_from_wav(fname: str):
+    fs = None
+    if os.path.isfile(fname):
+        audio, fs = librosa.load(fname, sr=None)
+        return fs
+    elif os.path.isdir(fname):
+        dir_files = os.listdir(fname)
+        for file in dir_files:
+            file_path = os.path.join(fname, file)
+            if os.path.isfile(file_path):
+                if file_path.endswith('.wav') or file_path.endswith('.WAV'):
+                    fs = get_sr_from_wav(file_path)
+            elif os.path.isdir(file_path):
+                fs = get_sr_from_wav(file_path)
+
+            if fs is not None:
+                break
+
+    return fs
+
+
+def find_file_by_ends(dir_path: str, ends: str):
+    dir_files = os.listdir(dir_path)
+    for file in dir_files:
+        file_path = os.path.join(dir_path, file)
+        if os.path.isfile(file_path):
+            if file_path.endswith(ends):
+                return True
+        elif os.path.isdir(file_path):
+            if find_file_by_ends(file_path, ends):
+                return True
+
+    return False
+
+
+def recursion_dir_all_wav(wav_list, dir_path: str) -> List[str]:
+    dir_files = os.listdir(dir_path)
+    for file in dir_files:
+        file_path = os.path.join(dir_path, file)
+        if os.path.isfile(file_path):
+            if file_path.endswith('.wav') or file_path.endswith('.WAV'):
+                wav_list.append(file_path)
+        elif os.path.isdir(file_path):
+            recursion_dir_all_wav(wav_list, file_path)
+
+    return wav_list
+
+
+def set_parameters(language: str = None):
+    if language is not None:
+        global global_asr_language
+        global_asr_language = language
+
+
+def compute_wer(hyp_list: List[Any],
+                ref_list: List[Any],
+                lang: str = None) -> Dict[str, Any]:
+    assert len(hyp_list) > 0, 'hyp list is empty'
+    assert len(ref_list) > 0, 'ref list is empty'
+
+    if lang is not None:
+        global global_asr_language
+        global_asr_language = lang
+
+    rst = {
+        'Wrd': 0,
+        'Corr': 0,
+        'Ins': 0,
+        'Del': 0,
+        'Sub': 0,
+        'Snt': 0,
+        'Err': 0.0,
+        'S.Err': 0.0,
+        'wrong_words': 0,
+        'wrong_sentences': 0
+    }
+
+    for h_item in hyp_list:
+        for r_item in ref_list:
+            if h_item['key'] == r_item['key']:
+                out_item = compute_wer_by_line(h_item['value'],
+                                               r_item['value'],
+                                               global_asr_language)
+                rst['Wrd'] += out_item['nwords']
+                rst['Corr'] += out_item['cor']
+                rst['wrong_words'] += out_item['wrong']
+                rst['Ins'] += out_item['ins']
+                rst['Del'] += out_item['del']
+                rst['Sub'] += out_item['sub']
+                rst['Snt'] += 1
+                if out_item['wrong'] > 0:
+                    rst['wrong_sentences'] += 1
+                    print_wrong_sentence(key=h_item['key'],
+                                         hyp=h_item['value'],
+                                         ref=r_item['value'])
+                else:
+                    print_correct_sentence(key=h_item['key'],
+                                           hyp=h_item['value'],
+                                           ref=r_item['value'])
+
+                break
+
+    if rst['Wrd'] > 0:
+        rst['Err'] = round(rst['wrong_words'] * 100 / rst['Wrd'], 2)
+    if rst['Snt'] > 0:
+        rst['S.Err'] = round(rst['wrong_sentences'] * 100 / rst['Snt'], 2)
+
+    return rst
+
+
+def compute_wer_by_line(hyp: List[str],
+                        ref: List[str],
+                        lang: str = 'zh-cn') -> Dict[str, Any]:
+    if lang != 'zh-cn':
+        hyp = hyp.split()
+        ref = ref.split()
+
+    hyp = list(map(lambda x: x.lower(), hyp))
+    ref = list(map(lambda x: x.lower(), ref))
+
+    len_hyp = len(hyp)
+    len_ref = len(ref)
+
+    cost_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int16)
+
+    ops_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int8)
+
+    for i in range(len_hyp + 1):
+        cost_matrix[i][0] = i
+    for j in range(len_ref + 1):
+        cost_matrix[0][j] = j
+
+    for i in range(1, len_hyp + 1):
+        for j in range(1, len_ref + 1):
+            if hyp[i - 1] == ref[j - 1]:
+                cost_matrix[i][j] = cost_matrix[i - 1][j - 1]
+            else:
+                substitution = cost_matrix[i - 1][j - 1] + 1
+                insertion = cost_matrix[i - 1][j] + 1
+                deletion = cost_matrix[i][j - 1] + 1
+
+                compare_val = [substitution, insertion, deletion]
+
+                min_val = min(compare_val)
+                operation_idx = compare_val.index(min_val) + 1
+                cost_matrix[i][j] = min_val
+                ops_matrix[i][j] = operation_idx
+
+    match_idx = []
+    i = len_hyp
+    j = len_ref
+    rst = {
+        'nwords': len_ref,
+        'cor': 0,
+        'wrong': 0,
+        'ins': 0,
+        'del': 0,
+        'sub': 0
+    }
+    while i >= 0 or j >= 0:
+        i_idx = max(0, i)
+        j_idx = max(0, j)
+
+        if ops_matrix[i_idx][j_idx] == 0:  # correct
+            if i - 1 >= 0 and j - 1 >= 0:
+                match_idx.append((j - 1, i - 1))
+                rst['cor'] += 1
+
+            i -= 1
+            j -= 1
+
+        elif ops_matrix[i_idx][j_idx] == 2:  # insert
+            i -= 1
+            rst['ins'] += 1
+
+        elif ops_matrix[i_idx][j_idx] == 3:  # delete
+            j -= 1
+            rst['del'] += 1
+
+        elif ops_matrix[i_idx][j_idx] == 1:  # substitute
+            i -= 1
+            j -= 1
+            rst['sub'] += 1
+
+        if i < 0 and j >= 0:
+            rst['del'] += 1
+        elif j < 0 and i >= 0:
+            rst['ins'] += 1
+
+    match_idx.reverse()
+    wrong_cnt = cost_matrix[len_hyp][len_ref]
+    rst['wrong'] = wrong_cnt
+
+    return rst
+
+
+def print_wrong_sentence(key: str, hyp: str, ref: str):
+    space = len(key)
+    print(key + yellow_color + ' ref: ' + ref)
+    print(' ' * space + red_color + ' hyp: ' + hyp + end_color)
+
+
+def print_correct_sentence(key: str, hyp: str, ref: str):
+    space = len(key)
+    print(key + yellow_color + ' ref: ' + ref)
+    print(' ' * space + green_color + ' hyp: ' + hyp + end_color)
+
+
+def print_progress(percent):
+    if percent > 1:
+        percent = 1
+    res = int(50 * percent) * '#'
+    print('\r[%-50s] %d%%' % (res, int(100 * percent)), end='')

+ 174 - 0
funasr/utils/postprocess_utils.py

@@ -0,0 +1,174 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import string
+from typing import Any, List, Union
+
+
+def isChinese(ch: str):
+    if '\u4e00' <= ch <= '\u9fff':
+        return True
+    return False
+
+
+def isAllChinese(word: Union[List[Any], str]):
+    word_lists = []
+    table = str.maketrans('', '', string.punctuation)
+    for i in word:
+        cur = i.translate(table)
+        cur = cur.replace(' ', '')
+        cur = cur.replace('</s>', '')
+        cur = cur.replace('<s>', '')
+        word_lists.append(cur)
+
+    if len(word_lists) == 0:
+        return False
+
+    for ch in word_lists:
+        if isChinese(ch) is False:
+            return False
+    return True
+
+
+def isAllAlpha(word: Union[List[Any], str]):
+    word_lists = []
+    table = str.maketrans('', '', string.punctuation)
+    for i in word:
+        cur = i.translate(table)
+        cur = cur.replace(' ', '')
+        cur = cur.replace('</s>', '')
+        cur = cur.replace('<s>', '')
+        word_lists.append(cur)
+
+    if len(word_lists) == 0:
+        return False
+
+    for ch in word_lists:
+        if ch.isalpha() is False:
+            return False
+        elif ch.isalpha() is True and isChinese(ch) is True:
+            return False
+
+    return True
+
+
+def abbr_dispose(words: List[Any]) -> List[Any]:
+    words_size = len(words)
+    word_lists = []
+    abbr_begin = []
+    abbr_end = []
+    last_num = -1
+    for num in range(words_size):
+        if num <= last_num:
+            continue
+
+        if len(words[num]) == 1 and words[num].encode('utf-8').isalpha():
+            if num + 1 < words_size and words[
+                    num + 1] == ' ' and num + 2 < words_size and len(
+                        words[num +
+                              2]) == 1 and words[num +
+                                                 2].encode('utf-8').isalpha():
+                # found the begin of abbr
+                abbr_begin.append(num)
+                num += 2
+                abbr_end.append(num)
+                # to find the end of abbr
+                while True:
+                    num += 1
+                    if num < words_size and words[num] == ' ':
+                        num += 1
+                        if num < words_size and len(
+                                words[num]) == 1 and words[num].encode(
+                                    'utf-8').isalpha():
+                            abbr_end.pop()
+                            abbr_end.append(num)
+                            last_num = num
+                        else:
+                            break
+                    else:
+                        break
+
+    last_num = -1
+    for num in range(words_size):
+        if num <= last_num:
+            continue
+
+        if num in abbr_begin:
+            word_lists.append(words[num].upper())
+            num += 1
+            while num < words_size:
+                if num in abbr_end:
+                    word_lists.append(words[num].upper())
+                    last_num = num
+                    break
+                else:
+                    if words[num].encode('utf-8').isalpha():
+                        word_lists.append(words[num].upper())
+                num += 1
+        else:
+            word_lists.append(words[num])
+
+    return word_lists
+
+
+def sentence_postprocess(words: List[Any]):
+    middle_lists = []
+    word_lists = []
+    word_item = ''
+
+    # wash words lists
+    for i in words:
+        word = ''
+        if isinstance(i, str):
+            word = i
+        else:
+            word = i.decode('utf-8')
+
+        if word in ['<s>', '</s>', '<unk>']:
+            continue
+        else:
+            middle_lists.append(word)
+
+    # all chinese characters
+    if isAllChinese(middle_lists):
+        for ch in middle_lists:
+            word_lists.append(ch.replace(' ', ''))
+
+    # all alpha characters
+    elif isAllAlpha(middle_lists):
+        for ch in middle_lists:
+            word = ''
+            if '@@' in ch:
+                word = ch.replace('@@', '')
+                word_item += word
+            else:
+                word_item += ch
+                word_lists.append(word_item)
+                word_lists.append(' ')
+                word_item = ''
+
+    # mix characters
+    else:
+        alpha_blank = False
+        for ch in middle_lists:
+            word = ''
+            if isAllChinese(ch):
+                if alpha_blank is True:
+                    word_lists.pop()
+                word_lists.append(ch)
+                alpha_blank = False
+            elif '@@' in ch:
+                word = ch.replace('@@', '')
+                word_item += word
+                alpha_blank = False
+            elif isAllAlpha(ch):
+                word_item += ch
+                word_lists.append(word_item)
+                word_lists.append(' ')
+                word_item = ''
+                alpha_blank = True
+            else:
+                raise ValueError('invalid character: {}'.format(ch))
+
+    word_lists = abbr_dispose(word_lists)
+    sentence = ''.join(word_lists).strip()
+    return sentence

+ 178 - 0
funasr/utils/wav_utils.py

@@ -0,0 +1,178 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import math
+import os
+from typing import Any, Dict, Union
+
+import kaldiio
+import librosa
+import numpy as np
+import torch
+import torchaudio
+import torchaudio.compliance.kaldi as kaldi
+
+
+def ndarray_resample(audio_in: np.ndarray,
+                     fs_in: int = 16000,
+                     fs_out: int = 16000) -> np.ndarray:
+    audio_out = audio_in
+    if fs_in != fs_out:
+        audio_out = librosa.resample(audio_in, orig_sr=fs_in, target_sr=fs_out)
+    return audio_out
+
+
+def torch_resample(audio_in: torch.Tensor,
+                   fs_in: int = 16000,
+                   fs_out: int = 16000) -> torch.Tensor:
+    audio_out = audio_in
+    if fs_in != fs_out:
+        audio_out = torchaudio.transforms.Resample(orig_freq=fs_in,
+                                                   new_freq=fs_out)(audio_in)
+    return audio_out
+
+
+def extract_CMVN_featrures(mvn_file):
+    """
+    extract CMVN from cmvn.ark
+    """
+
+    if not os.path.exists(mvn_file):
+        return None
+    try:
+        cmvn = kaldiio.load_mat(mvn_file)
+        means = []
+        variance = []
+
+        for i in range(cmvn.shape[1] - 1):
+            means.append(float(cmvn[0][i]))
+
+        count = float(cmvn[0][-1])
+
+        for i in range(cmvn.shape[1] - 1):
+            variance.append(float(cmvn[1][i]))
+
+        for i in range(len(means)):
+            means[i] /= count
+            variance[i] = variance[i] / count - means[i] * means[i]
+            if variance[i] < 1.0e-20:
+                variance[i] = 1.0e-20
+            variance[i] = 1.0 / math.sqrt(variance[i])
+
+        cmvn = np.array([means, variance])
+        return cmvn
+    except Exception:
+        cmvn = extract_CMVN_features_txt(mvn_file)
+        return cmvn
+
+
+def extract_CMVN_features_txt(mvn_file):  # noqa
+    with open(mvn_file, 'r', encoding='utf-8') as f:
+        lines = f.readlines()
+
+    add_shift_list = []
+    rescale_list = []
+    for i in range(len(lines)):
+        line_item = lines[i].split()
+        if line_item[0] == '<AddShift>':
+            line_item = lines[i + 1].split()
+            if line_item[0] == '<LearnRateCoef>':
+                add_shift_line = line_item[3:(len(line_item) - 1)]
+                add_shift_list = list(add_shift_line)
+                continue
+        elif line_item[0] == '<Rescale>':
+            line_item = lines[i + 1].split()
+            if line_item[0] == '<LearnRateCoef>':
+                rescale_line = line_item[3:(len(line_item) - 1)]
+                rescale_list = list(rescale_line)
+                continue
+    add_shift_list_f = [float(s) for s in add_shift_list]
+    rescale_list_f = [float(s) for s in rescale_list]
+    cmvn = np.array([add_shift_list_f, rescale_list_f])
+    return cmvn
+
+
+def build_LFR_features(inputs, m=7, n=6):  # noqa
+    """
+    Actually, this implements stacking frames and skipping frames.
+    if m = 1 and n = 1, just return the origin features.
+    if m = 1 and n > 1, it works like skipping.
+    if m > 1 and n = 1, it works like stacking but only support right frames.
+    if m > 1 and n > 1, it works like LFR.
+
+    Args:
+        inputs_batch: inputs is T x D np.ndarray
+        m: number of frames to stack
+        n: number of frames to skip
+    """
+    # LFR_inputs_batch = []
+    # for inputs in inputs_batch:
+    LFR_inputs = []
+    T = inputs.shape[0]
+    T_lfr = int(np.ceil(T / n))
+    left_padding = np.tile(inputs[0], ((m - 1) // 2, 1))
+    inputs = np.vstack((left_padding, inputs))
+    T = T + (m - 1) // 2
+    for i in range(T_lfr):
+        if m <= T - i * n:
+            LFR_inputs.append(np.hstack(inputs[i * n:i * n + m]))
+        else:  # process last LFR frame
+            num_padding = m - (T - i * n)
+            frame = np.hstack(inputs[i * n:])
+            for _ in range(num_padding):
+                frame = np.hstack((frame, inputs[-1]))
+            LFR_inputs.append(frame)
+    return np.vstack(LFR_inputs)
+
+
+def compute_fbank(wav_file,
+                  num_mel_bins=80,
+                  frame_length=25,
+                  frame_shift=10,
+                  dither=0.0,
+                  is_pcm=False,
+                  fs: Union[int, Dict[Any, int]] = 16000):
+    audio_sr: int = 16000
+    model_sr: int = 16000
+    if isinstance(fs, int):
+        model_sr = fs
+        audio_sr = fs
+    else:
+        model_sr = fs['model_fs']
+        audio_sr = fs['audio_fs']
+
+    if is_pcm is True:
+        # byte(PCM16) to float32, and resample
+        value = wav_file
+        middle_data = np.frombuffer(value, dtype=np.int16)
+        middle_data = np.asarray(middle_data)
+        if middle_data.dtype.kind not in 'iu':
+            raise TypeError("'middle_data' must be an array of integers")
+        dtype = np.dtype('float32')
+        if dtype.kind != 'f':
+            raise TypeError("'dtype' must be a floating point type")
+
+        i = np.iinfo(middle_data.dtype)
+        abs_max = 2**(i.bits - 1)
+        offset = i.min + abs_max
+        waveform = np.frombuffer(
+            (middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
+        waveform = ndarray_resample(waveform, audio_sr, model_sr)
+        waveform = torch.from_numpy(waveform.reshape(1, -1))
+    else:
+        # load pcm from wav, and resample
+        waveform, audio_sr = torchaudio.load(wav_file)
+        waveform = waveform * (1 << 15)
+        waveform = torch_resample(waveform, audio_sr, model_sr)
+
+    mat = kaldi.fbank(waveform,
+                      num_mel_bins=num_mel_bins,
+                      frame_length=frame_length,
+                      frame_shift=frame_shift,
+                      dither=dither,
+                      energy_floor=0.0,
+                      window_type='hamming',
+                      sample_frequency=model_sr)
+
+    input_feats = mat
+
+    return input_feats

+ 1 - 1
funasr/version.txt

@@ -1 +1 @@
-0.1.0
+0.1.3