
Merge branch 'dev_wjm' of https://github.com/alibaba-damo-academy/FunASR into dev_wjm

speech_asr 3 years ago
parent commit 6d17715edf
31 changed files with 2523 additions and 88 deletions
  1. +10 -1  .github/workflows/main.yml
  2. +4 -3  README.md
  3. BIN  docs/images/dingding.jpg
  4. BIN  docs/images/wechat.png
  5. +53 -0  egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/README.md
  6. +40 -0  egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/RESULTS.md
  7. +35 -0  egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/finetune.py
  8. +103 -0  egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/infer.py
  9. +67 -0  egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/infer_after_finetune.py
  10. +1 -1  egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/infer.py
  11. +36 -0  egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/infer_aishell1_subtest_demo.py
  12. +6 -0  funasr/bin/asr_inference_launch.py
  13. +774 -0  funasr/bin/asr_inference_mfcca.py
  14. +2 -0  funasr/bin/build_trainer.py
  15. +2 -2  funasr/datasets/iterable_dataset.py
  16. +9 -19  funasr/export/README.md
  17. +13 -5  funasr/export/export_model.py
  18. +1 -1  funasr/export/models/e2e_asr_paraformer.py
  19. +0 -50  funasr/export/models/predictor/cif.py
  20. +322 -0  funasr/models/e2e_asr_mfcca.py
  21. +270 -0  funasr/models/encoder/encoder_layer_mfcca.py
  22. +450 -0  funasr/models/encoder/mfcca_encoder.py
  23. +125 -0  funasr/models/frontend/default.py
  24. +0 -3  funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/.gitignore
  25. +12 -2  funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/README.md
  26. +9 -0  funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/demo.py
  27. +0 -0  funasr/runtime/python/torchscripts/__init__.py
  28. +0 -0  funasr/runtime/python/torchscripts/paraformer/__init__.py
  29. +9 -1  funasr/tasks/abs_task.py
  30. +138 -0  funasr/tasks/asr.py
  31. +32 -0  funasr/utils/wav_utils.py

+ 10 - 1
.github/workflows/main.yml

@@ -12,6 +12,10 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v1
+      - uses: ammaraskar/sphinx-action@master
+        with:
+          docs-folder: "docs/"
+          pre-build-command: "pip install sphinx-markdown-tables nbsphinx jinja2 recommonmark sphinx_rtd_theme"
       - uses: ammaraskar/sphinx-action@master
         with:
           docs-folder: "docs_cn/"
@@ -22,7 +26,12 @@ jobs:
         run: |
           mkdir public
           touch public/.nojekyll
-          cp -r docs_cn/_build/html/* public/
+          mkdir public/en
+          touch public/en/.nojekyll
+          cp -r docs/_build/html/* public/en/
+          mkdir public/cn
+          touch public/cn/.nojekyll
+          cp -r docs_cn/_build/html/* public/cn/
 
       - name: deploy github.io pages
         if: github.ref == 'refs/heads/main' || github.ref == 'refs/heads/dev_wjm'

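Editor's note: the workflow change above builds both the English and Chinese Sphinx docs and mirrors them into a two-folder GitHub Pages layout (`public/en/`, `public/cn/`, each with a `.nojekyll` marker). As a minimal local sketch of that staging step (assuming the Sphinx builds already exist; `stage_pages` is a hypothetical helper, not part of this commit):

```python
import os
import shutil

def stage_pages(build_dirs, public="public"):
    """Mirror per-language Sphinx HTML builds into public/<lang>/,
    adding .nojekyll markers as the deploy step above does."""
    os.makedirs(public, exist_ok=True)
    open(os.path.join(public, ".nojekyll"), "w").close()
    for lang, build_html in build_dirs.items():
        target = os.path.join(public, lang)
        os.makedirs(target, exist_ok=True)
        open(os.path.join(target, ".nojekyll"), "w").close()
        if os.path.isdir(build_html):
            # requires Python 3.8+ for dirs_exist_ok
            shutil.copytree(build_html, target, dirs_exist_ok=True)
    return public

# stage_pages({"en": "docs/_build/html", "cn": "docs_cn/_build/html"})
```
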
+ 4 - 3
README.md

@@ -1,4 +1,4 @@
-<div align="left"><img src="docs/images/funasr_logo.jpg" width="400"/></div>
+[//]: # (<div align="left"><img src="docs/images/funasr_logo.jpg" width="400"/></div>)
 
 # FunASR: A Fundamental End-to-End Speech Recognition Toolkit
 
@@ -7,7 +7,8 @@
 [**News**](https://github.com/alibaba-damo-academy/FunASR#whats-new) 
 | [**Highlights**](#highlights)
 | [**Installation**](#installation)
-| [**Docs**](https://alibaba-damo-academy.github.io/FunASR/index.html)
+| [**Docs_CN**](https://alibaba-damo-academy.github.io/FunASR/cn/index.html)
+| [**Docs_EN**](https://alibaba-damo-academy.github.io/FunASR/en/index.html)
 | [**Tutorial**](https://github.com/alibaba-damo-academy/FunASR/wiki#funasr%E7%94%A8%E6%88%B7%E6%89%8B%E5%86%8C)
 | [**Papers**](https://github.com/alibaba-damo-academy/FunASR#citations)
 | [**Runtime**](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/runtime)
@@ -42,7 +43,7 @@ pip install --editable ./
 For more details, please ref to [installation](https://github.com/alibaba-damo-academy/FunASR/wiki)
 
 ## Usage
-For users who are new to FunASR and ModelScope, please refer to [FunASR Docs](https://alibaba-damo-academy.github.io/FunASR/index.html).
+For users who are new to FunASR and ModelScope, please refer to FunASR Docs([CN](https://alibaba-damo-academy.github.io/FunASR/cn/index.html) / [EN](https://alibaba-damo-academy.github.io/FunASR/en/index.html))
 
 ## Contact
 

BIN
docs/images/dingding.jpg


BIN
docs/images/wechat.png


+ 53 - 0
egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/README.md

@@ -0,0 +1,53 @@
+# ModelScope Model
+
+## How to finetune and infer using a pretrained MFCCA model
+
+### Finetune
+
+- Modify finetune training related parameters in `finetune.py`
+    - <strong>output_dir:</strong> # result dir
+    - <strong>data_dir:</strong> # the dataset dir needs to include files: `train/wav.scp`, `train/text`; `validation/wav.scp`, `validation/text`
+    - <strong>dataset_type:</strong> # for dataset larger than 1000 hours, set as `large`, otherwise set as `small`
+    - <strong>batch_bins:</strong> # batch size. If `dataset_type` is `small`, `batch_bins` is the number of feature frames; if `large`, it is the duration in milliseconds
+    - <strong>max_epoch:</strong> # number of training epochs
+    - <strong>lr:</strong> # learning rate
+
+- Then you can run the pipeline to finetune with:
+```shell
+python finetune.py
+```
+
+### Inference
+
+Or you can use the pretrained model for inference directly.
+
+- Setting parameters in `infer.py`
+    - <strong>data_dir:</strong> # the dataset dir needs to include `test/wav.scp`. If `test/text` also exists, the CER will be computed
+    - <strong>output_dir:</strong> # result dir
+    - <strong>ngpu:</strong> # the number of GPUs for decoding
+    - <strong>njob:</strong> # the number of jobs for each GPU
+
+- Then you can run the pipeline to infer with:
+```shell
+python infer.py
+```
+
+- Results
+
+The decoding results can be found in `$output_dir/1best_recog/text.sp.cer` and `$output_dir/1best_recog/text.nosp.cer`, which contain the recognition result of each sample with or without the separating character (src), together with the CER of the whole test set.
+
+### Inference using local finetuned model
+
+- Modify inference related parameters in `infer_after_finetune.py`
+    - <strong>output_dir:</strong> # result dir
+    - <strong>data_dir:</strong> # the dataset dir needs to include `test/wav.scp`. If `test/text` also exists, the CER will be computed
+    - <strong>decoding_model_name:</strong> # set the checkpoint name for decoding, e.g., `valid.cer_ctc.ave.pth`
+
+- Then you can run the pipeline to infer with:
+```shell
+python infer_after_finetune.py
+```
+
+- Results
+
+The decoding results can be found in `$output_dir/1best_recog/text.sp.cer` and `$output_dir/1best_recog/text.nosp.cer`, which contain the recognition result of each sample with or without the separating character (src), together with the CER of the whole test set.
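
Editor's note: the `wav.scp` / `text` files referenced throughout this README follow the usual Kaldi-style layout of one `utterance-id<space>value` line per utterance (the same format the `infer_aishell1_subtest_demo.py` script in this commit writes). A sketch of preparing such a pair, with hypothetical utterance ids and paths:

```python
import os

def write_scp_and_text(entries, out_dir):
    """Write Kaldi-style wav.scp and text files:
    one 'utt-id value' line per utterance in each file."""
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, "wav.scp"), "w") as scp, \
         open(os.path.join(out_dir, "text"), "w") as txt:
        for utt_id, wav_path, transcript in entries:
            scp.write(f"{utt_id} {wav_path}\n")
            txt.write(f"{utt_id} {transcript}\n")

# write_scp_and_text([("utt1", "/data/utt1.wav", "你好 世界")], "example_data/train")
```
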

+ 40 - 0
egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/RESULTS.md

@@ -0,0 +1,40 @@
+# MFCCA
+- Model link: <https://www.modelscope.cn/models/yufan6/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/summary>
+- Model size: 45M
+
+# Environments
+- date: `Tue Feb 13 20:13:22 CST 2023`
+- python version: `3.7.12`
+- FunASR version: `0.1.0`
+- pytorch version: `pytorch 1.7.0`
+- Git hash: ``
+- Commit date: ``
+
+# Benchmark Results
+
+## result (paper)
+beam=20, CER tool: https://github.com/yufan-aslp/AliMeeting
+
+|        model        | Para (M) | Data (hrs) | Eval (CER%) | Test (CER%) |
+|:-------------------:|:---------:|:---------:|:---------:| :---------:|
+| MFCCA | 45   |   917  |   16.1   | 17.5   |
+
+## result (modelscope)
+
+beam=10
+
+with separating character (src)
+
+|        model        | Para (M) | Data (hrs) | Eval_sp (CER%) | Test_sp (CER%) | 
+|:-------------------:|:---------:|:---------:|:---------:| :---------:|
+| MFCCA | 45   |   917  |   17.1   | 18.6   |
+
+without separating character (src)
+
+|        model        | Para (M) | Data (hrs) | Eval_nosp (CER%) | Test_nosp (CER%) | 
+|:-------------------:|:---------:|:---------:|:---------:| :---------:|
+| MFCCA | 45   |   917  |   16.4   | 18.0   |
+
+## Deviation
+
+Given the differences in the CER calculation tool and the decoding beam size, the CER results deviate slightly (<0.5%).

+ 35 - 0
egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/finetune.py

@@ -0,0 +1,35 @@
+import os
+from modelscope.metainfo import Trainers
+from modelscope.trainers import build_trainer
+from funasr.datasets.ms_dataset import MsDataset
+from funasr.utils.modelscope_param import modelscope_args
+
+def modelscope_finetune(params):
+    if not os.path.exists(params.output_dir):
+        os.makedirs(params.output_dir, exist_ok=True)
+    # dataset split ["train", "validation"]
+    ds_dict = MsDataset.load(params.data_path)
+    kwargs = dict(
+        model=params.model,
+        model_revision=params.model_revision,
+        data_dir=ds_dict,
+        dataset_type=params.dataset_type,
+        work_dir=params.output_dir,
+        batch_bins=params.batch_bins,
+        max_epoch=params.max_epoch,
+        lr=params.lr)
+    trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
+    trainer.train()
+
+
+if __name__ == '__main__':
+    
+    params = modelscope_args(model="yufan6/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950")
+    params.output_dir = "./checkpoint"              # model save path
+    params.data_path = "./example_data/"            # data path
+    params.dataset_type = "small"                   # set "small" for small datasets; use "large" if the data exceeds 1000 hours
+    params.batch_bins = 1000                        # batch size: feature frames if dataset_type="small", milliseconds if dataset_type="large"
+    params.max_epoch = 10                           # maximum number of training epochs
+    params.lr = 0.0001                              # learning rate
+    params.model_revision = 'v2.0.0'
+    modelscope_finetune(params)

+ 103 - 0
egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/infer.py

@@ -0,0 +1,103 @@
+import os
+import shutil
+from multiprocessing import Pool
+
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+
+from funasr.utils.compute_wer import compute_wer
+
+def modelscope_infer_core(output_dir, split_dir, njob, idx):
+    output_dir_job = os.path.join(output_dir, "output.{}".format(idx))
+    gpu_id = (int(idx) - 1) // njob
+    if "CUDA_VISIBLE_DEVICES" in os.environ.keys():
+        gpu_list = os.environ['CUDA_VISIBLE_DEVICES'].split(",")
+        os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_list[gpu_id])
+    else:
+        os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
+    inference_pipline = pipeline(
+        task=Tasks.auto_speech_recognition,
+        model='yufan6/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950',
+        model_revision='v2.0.0',
+        output_dir=output_dir_job,
+        batch_size=1,
+    )
+    audio_in = os.path.join(split_dir, "wav.{}.scp".format(idx))
+    inference_pipline(audio_in=audio_in)
+
+
+def modelscope_infer(params):
+    # prepare for multi-GPU decoding
+    ngpu = params["ngpu"]
+    njob = params["njob"]
+    output_dir = params["output_dir"]
+    if os.path.exists(output_dir):
+        shutil.rmtree(output_dir)
+    os.mkdir(output_dir)
+    split_dir = os.path.join(output_dir, "split")
+    os.mkdir(split_dir)
+    nj = ngpu * njob
+    wav_scp_file = os.path.join(params["data_dir"], "wav.scp")
+    with open(wav_scp_file) as f:
+        lines = f.readlines()
+        num_lines = len(lines)
+        num_job_lines = num_lines // nj
+    start = 0
+    for i in range(nj):
+        end = start + num_job_lines
+        file = os.path.join(split_dir, "wav.{}.scp".format(str(i + 1)))
+        with open(file, "w") as f:
+            if i == nj - 1:
+                f.writelines(lines[start:])
+            else:
+                f.writelines(lines[start:end])
+        start = end
+    p = Pool(nj)
+    for i in range(nj):
+        p.apply_async(modelscope_infer_core,
+                      args=(output_dir, split_dir, njob, str(i + 1)))
+    p.close()
+    p.join()
+
+    # combine decoding results
+    best_recog_path = os.path.join(output_dir, "1best_recog")
+    os.mkdir(best_recog_path)
+    files = ["text", "token", "score"]
+    for file in files:
+        with open(os.path.join(best_recog_path, file), "w") as f:
+            for i in range(nj):
+                job_file = os.path.join(output_dir, "output.{}/1best_recog".format(str(i + 1)), file)
+                with open(job_file) as f_job:
+                    lines = f_job.readlines()
+                f.writelines(lines)
+
+    # If text exists, compute CER
+    text_in = os.path.join(params["data_dir"], "text")
+    if os.path.exists(text_in):
+        text_proc_file = os.path.join(best_recog_path, "token")
+        text_proc_file2 = os.path.join(best_recog_path, "token_nosep")
+        with open(text_proc_file, 'r') as hyp_reader:
+            with open(text_proc_file2, 'w') as hyp_writer:
+                for line in hyp_reader:
+                    new_context = line.strip().replace("src", "").replace("  ", " ").replace("  ", " ").strip()
+                    hyp_writer.write(new_context + '\n')
+        text_in2 = os.path.join(best_recog_path, "ref_text_nosep")
+        with open(text_in, 'r') as ref_reader:
+            with open(text_in2, 'w') as ref_writer:
+                for line in ref_reader:
+                    new_context = line.strip().replace("src","").replace("  "," ").replace("  "," ").strip()
+                    ref_writer.write(new_context+'\n')
+
+
+        compute_wer(text_in, text_proc_file, os.path.join(best_recog_path, "text.sp.cer"))
+        compute_wer(text_in2, text_proc_file2, os.path.join(best_recog_path, "text.nosp.cer"))
+
+
+if __name__ == "__main__":
+    params = {}
+    params["data_dir"] = "./example_data/validation"
+    params["output_dir"] = "./output_dir"
+    params["ngpu"] = 1
+    params["njob"] = 1
+    modelscope_infer(params)
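
Editor's note: the wav.scp splitting in `modelscope_infer` above hands each of the `ngpu * njob` workers `num_lines // nj` contiguous lines, with the last worker absorbing the remainder. The same logic as a standalone helper (a sketch for illustration, not part of the committed file):

```python
def split_lines(lines, nj):
    """Split lines into nj contiguous chunks of len(lines) // nj each;
    the last chunk takes the remainder, as in modelscope_infer."""
    per_job = len(lines) // nj
    chunks = []
    start = 0
    for i in range(nj):
        end = start + per_job
        chunks.append(lines[start:] if i == nj - 1 else lines[start:end])
        start = end
    return chunks

# split_lines(list(range(10)), 3)  -> [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]
```
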

+ 67 - 0
egs_modelscope/asr/mfcca/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950/infer_after_finetune.py

@@ -0,0 +1,67 @@
+import json
+import os
+import shutil
+
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+
+from funasr.utils.compute_wer import compute_wer
+
+
+def modelscope_infer_after_finetune(params):
+    # prepare for decoding
+    pretrained_model_path = os.path.join(os.environ["HOME"], ".cache/modelscope/hub", params["modelscope_model_name"])
+    for file_name in params["required_files"]:
+        if file_name == "configuration.json":
+            with open(os.path.join(pretrained_model_path, file_name)) as f:
+                config_dict = json.load(f)
+                config_dict["model"]["am_model_name"] = params["decoding_model_name"]
+            with open(os.path.join(params["output_dir"], "configuration.json"), "w") as f:
+                json.dump(config_dict, f, indent=4, separators=(',', ': '))
+        else:
+            shutil.copy(os.path.join(pretrained_model_path, file_name),
+                        os.path.join(params["output_dir"], file_name))
+    decoding_path = os.path.join(params["output_dir"], "decode_results")
+    if os.path.exists(decoding_path):
+        shutil.rmtree(decoding_path)
+    os.mkdir(decoding_path)
+
+    # decoding
+    inference_pipeline = pipeline(
+        task=Tasks.auto_speech_recognition,
+        model=params["output_dir"],
+        output_dir=decoding_path,
+        batch_size=1
+    )
+    audio_in = os.path.join(params["data_dir"], "wav.scp")
+    inference_pipeline(audio_in=audio_in)
+
+    # compute CER if the ground-truth text exists
+    text_in = os.path.join(params["data_dir"], "text")
+    if os.path.exists(text_in):
+        text_proc_file = os.path.join(decoding_path, "1best_recog/token")
+        text_proc_file2 = os.path.join(decoding_path, "1best_recog/token_nosep")
+        with open(text_proc_file, 'r') as hyp_reader:
+            with open(text_proc_file2, 'w') as hyp_writer:
+                for line in hyp_reader:
+                    new_context = line.strip().replace("src", "").replace("  ", " ").replace("  ", " ").strip()
+                    hyp_writer.write(new_context + '\n')
+        text_in2 = os.path.join(decoding_path, "1best_recog/ref_text_nosep")
+        with open(text_in, 'r') as ref_reader:
+            with open(text_in2, 'w') as ref_writer:
+                for line in ref_reader:
+                    new_context = line.strip().replace("src","").replace("  "," ").replace("  "," ").strip()
+                    ref_writer.write(new_context+'\n')
+
+
+        compute_wer(text_in, text_proc_file, os.path.join(decoding_path, "text.sp.cer"))
+        compute_wer(text_in2, text_proc_file2, os.path.join(decoding_path, "text.nosp.cer"))
+
+if __name__ == '__main__':
+    params = {}
+    params["modelscope_model_name"] = "yufan6/speech_mfcca_asr-zh-cn-16k-alimeeting-vocab4950"
+    params["required_files"] = ["feats_stats.npz", "decoding.yaml", "configuration.json"]
+    params["output_dir"] = "./checkpoint"
+    params["data_dir"] = "./example_data/validation"
+    params["decoding_model_name"] = "valid.acc.ave.pth"
+    modelscope_infer_after_finetune(params)
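
Editor's note: both inference scripts in this commit normalize hypotheses and references the same way before computing `text.nosp.cer`: drop the literal `src` separator token and collapse the resulting double spaces. As a standalone sketch (slightly generalized from the two fixed `replace` calls to a loop):

```python
def strip_separator(line, sep="src"):
    """Remove the separator token and collapse runs of spaces,
    mirroring the text.nosp.cer preparation above."""
    out = line.strip().replace(sep, "")
    while "  " in out:
        out = out.replace("  ", " ")
    return out.strip()
```
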

+ 1 - 1
egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/infer.py

@@ -6,7 +6,7 @@ if __name__ == '__main__':
     param_dict = dict()
     param_dict['hotword'] = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/hotword.txt"
 
-    audio_in = "//isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_hotword.wav"
+    audio_in = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_hotword.wav"
     output_dir = None
     batch_size = 1
 

+ 36 - 0
egs_modelscope/asr/paraformer/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/infer_aishell1_subtest_demo.py

@@ -0,0 +1,36 @@
+import os
+import tempfile
+import codecs
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.msdatasets import MsDataset
+
+if __name__ == '__main__':
+    param_dict = dict()
+    param_dict['hotword'] = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/hotword.txt"
+
+    output_dir = "./output"
+    batch_size = 1
+
+    # dataset split ['test']
+    ds_dict = MsDataset.load(dataset_name='speech_asr_aishell1_hotwords_testsets', namespace='speech_asr')
+    work_dir = tempfile.TemporaryDirectory().name
+    if not os.path.exists(work_dir):
+        os.makedirs(work_dir)
+    wav_file_path = os.path.join(work_dir, "wav.scp")
+    
+    with codecs.open(wav_file_path, 'w') as fin: 
+        for line in ds_dict:
+            wav = line["Audio:FILE"]
+            idx = wav.split("/")[-1].split(".")[0]
+            fin.writelines(idx + " " + wav + "\n")
+    audio_in = wav_file_path         
+
+    inference_pipeline = pipeline(
+        task=Tasks.auto_speech_recognition,
+        model="damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404",
+        output_dir=output_dir,
+        batch_size=batch_size,
+        param_dict=param_dict)
+
+    rec_result = inference_pipeline(audio_in=audio_in)

+ 6 - 0
funasr/bin/asr_inference_launch.py

@@ -228,6 +228,9 @@ def inference_launch(**kwargs):
     elif mode == "vad":
         from funasr.bin.vad_inference import inference_modelscope
         return inference_modelscope(**kwargs)
+    elif mode == "mfcca":
+        from funasr.bin.asr_inference_mfcca import inference_modelscope
+        return inference_modelscope(**kwargs)
     else:
         logging.info("Unknown decoding mode: {}".format(mode))
         return None
@@ -253,6 +256,9 @@ def inference_launch_funasr(**kwargs):
     elif mode == "vad":
         from funasr.bin.vad_inference import inference
         return inference(**kwargs)
+    elif mode == "mfcca":
+        from funasr.bin.asr_inference_mfcca import inference_modelscope
+        return inference_modelscope(**kwargs)
     else:
         logging.info("Unknown decoding mode: {}".format(mode))
         return None

+ 774 - 0
funasr/bin/asr_inference_mfcca.py

@@ -0,0 +1,774 @@
+#!/usr/bin/env python3
+# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
+#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
+
+import argparse
+import logging
+import sys
+from pathlib import Path
+from typing import Any
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Union
+from typing import Dict
+
+import numpy as np
+import torch
+from typeguard import check_argument_types
+from typeguard import check_return_type
+
+from funasr.fileio.datadir_writer import DatadirWriter
+from funasr.modules.beam_search.batch_beam_search import BatchBeamSearch
+from funasr.modules.beam_search.beam_search import BeamSearch
+from funasr.modules.beam_search.beam_search import Hypothesis
+from funasr.modules.scorers.ctc import CTCPrefixScorer
+from funasr.modules.scorers.length_bonus import LengthBonus
+from funasr.modules.scorers.scorer_interface import BatchScorerInterface
+from funasr.modules.subsampling import TooShortUttError
+from funasr.tasks.asr import ASRTaskMFCCA as ASRTask
+from funasr.tasks.lm import LMTask
+from funasr.text.build_tokenizer import build_tokenizer
+from funasr.text.token_id_converter import TokenIDConverter
+from funasr.torch_utils.device_funcs import to_device
+from funasr.torch_utils.set_all_random_seed import set_all_random_seed
+from funasr.utils import config_argparse
+from funasr.utils.cli_utils import get_commandline_args
+from funasr.utils.types import str2bool
+from funasr.utils.types import str2triple_str
+from funasr.utils.types import str_or_none
+from funasr.utils import asr_utils, wav_utils, postprocess_utils
+
+header_colors = '\033[95m'
+end_colors = '\033[0m'
+
+global_asr_language: str = 'zh-cn'
+global_sample_rate: Union[int, Dict[Any, int]] = {
+    'audio_fs': 16000,
+    'model_fs': 16000
+}
+
+class Speech2Text:
+    """Speech2Text class
+
+    Examples:
+        >>> import soundfile
+        >>> speech2text = Speech2Text("asr_config.yml", "asr.pth")
+        >>> audio, rate = soundfile.read("speech.wav")
+        >>> speech2text(audio)
+        [(text, token, token_int, hypothesis object), ...]
+
+    """
+
+    def __init__(
+            self,
+            asr_train_config: Union[Path, str] = None,
+            asr_model_file: Union[Path, str] = None,
+            cmvn_file: Union[Path, str] = None,
+            lm_train_config: Union[Path, str] = None,
+            lm_file: Union[Path, str] = None,
+            token_type: str = None,
+            bpemodel: str = None,
+            device: str = "cpu",
+            maxlenratio: float = 0.0,
+            minlenratio: float = 0.0,
+            batch_size: int = 1,
+            dtype: str = "float32",
+            beam_size: int = 20,
+            ctc_weight: float = 0.5,
+            lm_weight: float = 1.0,
+            ngram_weight: float = 0.9,
+            penalty: float = 0.0,
+            nbest: int = 1,
+            streaming: bool = False,
+            **kwargs,
+    ):
+        assert check_argument_types()
+
+        # 1. Build ASR model
+        scorers = {}
+        asr_model, asr_train_args = ASRTask.build_model_from_file(
+            asr_train_config, asr_model_file, cmvn_file, device
+        )
+
+        logging.info("asr_model: {}".format(asr_model))
+        logging.info("asr_train_args: {}".format(asr_train_args))
+        asr_model.to(dtype=getattr(torch, dtype)).eval()
+
+        decoder = asr_model.decoder
+
+        ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
+        token_list = asr_model.token_list
+        scorers.update(
+            decoder=decoder,
+            ctc=ctc,
+            length_bonus=LengthBonus(len(token_list)),
+        )
+
+        # 2. Build Language model
+        if lm_train_config is not None:
+            lm, lm_train_args = LMTask.build_model_from_file(
+                lm_train_config, lm_file, device
+            )
+            lm.to(device)
+            scorers["lm"] = lm.lm
+        # 3. Build ngram model
+        # ngram is not supported now
+        ngram = None
+        scorers["ngram"] = ngram
+
+        # 4. Build BeamSearch object
+        # transducer is not supported now
+        beam_search_transducer = None
+
+        weights = dict(
+            decoder=1.0 - ctc_weight,
+            ctc=ctc_weight,
+            lm=lm_weight,
+            ngram=ngram_weight,
+            length_bonus=penalty,
+        )
+        beam_search = BeamSearch(
+            beam_size=beam_size,
+            weights=weights,
+            scorers=scorers,
+            sos=asr_model.sos,
+            eos=asr_model.eos,
+            vocab_size=len(token_list),
+            token_list=token_list,
+            pre_beam_score_key=None if ctc_weight == 1.0 else "full",
+        )
+        #beam_search.__class__ = BatchBeamSearch
+        # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
+        if token_type is None:
+            token_type = asr_train_args.token_type
+        if bpemodel is None:
+            bpemodel = asr_train_args.bpemodel
+        
+        if token_type is None:
+            tokenizer = None
+        elif token_type == "bpe":
+            if bpemodel is not None:
+                tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
+            else:
+                tokenizer = None
+        else:
+            tokenizer = build_tokenizer(token_type=token_type)
+        converter = TokenIDConverter(token_list=token_list)
+        logging.info(f"Text tokenizer: {tokenizer}")
+
+        self.asr_model = asr_model
+        self.asr_train_args = asr_train_args
+        self.converter = converter
+        self.tokenizer = tokenizer
+        self.beam_search = beam_search
+        self.beam_search_transducer = beam_search_transducer
+        self.maxlenratio = maxlenratio
+        self.minlenratio = minlenratio
+        self.device = device
+        self.dtype = dtype
+        self.nbest = nbest
+
+    @torch.no_grad()
+    def __call__(
+            self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
+    ) -> List[
+        Tuple[
+            Optional[str],
+            List[str],
+            List[int],
+            Union[Hypothesis],
+        ]
+    ]:
+        """Inference
+
+        Args:
+            speech: Input speech data
+        Returns:
+            text, token, token_int, hyp
+
+        """
+        assert check_argument_types()
+        # Input as audio signal
+        if isinstance(speech, np.ndarray):
+            speech = torch.tensor(speech)
+
+
+        #speech = speech.unsqueeze(0).to(getattr(torch, self.dtype))
+        speech = speech.to(getattr(torch, self.dtype))
+        # lengths: (1,)
+        lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
+        batch = {"speech": speech, "speech_lengths": lengths}
+
+        # a. To device
+        batch = to_device(batch, device=self.device)
+        
+        # b. Forward Encoder
+        enc, _ = self.asr_model.encode(**batch)
+        
+        assert len(enc) == 1, len(enc)
+
+        # c. Pass the encoder output to the beam search
+        nbest_hyps = self.beam_search(
+            x=enc[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
+        )
+        
+        nbest_hyps = nbest_hyps[: self.nbest]
+        
+        results = []
+        for hyp in nbest_hyps:
+            assert isinstance(hyp, (Hypothesis)), type(hyp)
+
+            # remove sos/eos and get results
+            last_pos = -1
+            if isinstance(hyp.yseq, list):
+                token_int = hyp.yseq[1:last_pos]
+            else:
+                token_int = hyp.yseq[1:last_pos].tolist()
+
+            # remove blank symbol id, which is assumed to be 0
+            token_int = list(filter(lambda x: x != 0, token_int))
+
+            # Change integer-ids to tokens
+            token = self.converter.ids2tokens(token_int)
+
+            if self.tokenizer is not None:
+                text = self.tokenizer.tokens2text(token)
+            else:
+                text = None
+            results.append((text, token, token_int, hyp))
+
+        assert check_return_type(results)
+        return results
+
+
+# def inference(
+#         maxlenratio: float,
+#         minlenratio: float,
+#         batch_size: int,
+#         beam_size: int,
+#         ngpu: int,
+#         ctc_weight: float,
+#         lm_weight: float,
+#         penalty: float,
+#         log_level: Union[int, str],
+#         data_path_and_name_and_type,
+#         asr_train_config: Optional[str],
+#         asr_model_file: Optional[str],
+#         cmvn_file: Optional[str] = None,
+#         lm_train_config: Optional[str] = None,
+#         lm_file: Optional[str] = None,
+#         token_type: Optional[str] = None,
+#         key_file: Optional[str] = None,
+#         word_lm_train_config: Optional[str] = None,
+#         bpemodel: Optional[str] = None,
+#         allow_variable_data_keys: bool = False,
+#         streaming: bool = False,
+#         output_dir: Optional[str] = None,
+#         dtype: str = "float32",
+#         seed: int = 0,
+#         ngram_weight: float = 0.9,
+#         nbest: int = 1,
+#         num_workers: int = 1,
+#         **kwargs,
+# ):
+#     assert check_argument_types()
+#     if batch_size > 1:
+#         raise NotImplementedError("batch decoding is not implemented")
+#     if word_lm_train_config is not None:
+#         raise NotImplementedError("Word LM is not implemented")
+#     if ngpu > 1:
+#         raise NotImplementedError("only single GPU decoding is supported")
+#
+#     logging.basicConfig(
+#         level=log_level,
+#         format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+#     )
+#
+#     if ngpu >= 1 and torch.cuda.is_available():
+#         device = "cuda"
+#     else:
+#         device = "cpu"
+#
+#     # 1. Set random-seed
+#     set_all_random_seed(seed)
+#
+#     # 2. Build speech2text
+#     speech2text_kwargs = dict(
+#         asr_train_config=asr_train_config,
+#         asr_model_file=asr_model_file,
+#         cmvn_file=cmvn_file,
+#         lm_train_config=lm_train_config,
+#         lm_file=lm_file,
+#         token_type=token_type,
+#         bpemodel=bpemodel,
+#         device=device,
+#         maxlenratio=maxlenratio,
+#         minlenratio=minlenratio,
+#         dtype=dtype,
+#         beam_size=beam_size,
+#         ctc_weight=ctc_weight,
+#         lm_weight=lm_weight,
+#         ngram_weight=ngram_weight,
+#         penalty=penalty,
+#         nbest=nbest,
+#         streaming=streaming,
+#     )
+#     logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
+#     speech2text = Speech2Text(**speech2text_kwargs)
+#
+#     # 3. Build data-iterator
+#     loader = ASRTask.build_streaming_iterator(
+#         data_path_and_name_and_type,
+#         dtype=dtype,
+#         batch_size=batch_size,
+#         key_file=key_file,
+#         num_workers=num_workers,
+#         preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
+#         collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
+#         allow_variable_data_keys=allow_variable_data_keys,
+#         inference=True,
+#     )
+#
+#     finish_count = 0
+#     file_count = 1
+#     # 4. Start for-loop
+#     # FIXME(kamo): The output format should be discussed
+#     asr_result_list = []
+#     if output_dir is not None:
+#         writer = DatadirWriter(output_dir)
+#     else:
+#         writer = None
+#
+#     for keys, batch in loader:
+#         assert isinstance(batch, dict), type(batch)
+#         assert all(isinstance(s, str) for s in keys), keys
+#         _bs = len(next(iter(batch.values())))
+#         assert len(keys) == _bs, f"{len(keys)} != {_bs}"
+#         #batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
+#
+#         # N-best list of (text, token, token_int, hyp_object)
+#         try:
+#             results = speech2text(**batch)
+#         except TooShortUttError as e:
+#             logging.warning(f"Utterance {keys} {e}")
+#             hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
+#             results = [[" ", ["<space>"], [2], hyp]] * nbest
+#
+#         # Only supporting batch_size==1
+#         key = keys[0]
+#         for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
+#             # Create a directory: outdir/{n}best_recog
+#             if writer is not None:
+#                 ibest_writer = writer[f"{n}best_recog"]
+#
+#                 # Write the result to each file
+#                 ibest_writer["token"][key] = " ".join(token)
+#                 ibest_writer["token_int"][key] = " ".join(map(str, token_int))
+#                 ibest_writer["score"][key] = str(hyp.score)
+#
+#             if text is not None:
+#                 text_postprocessed = postprocess_utils.sentence_postprocess(token)
+#                 item = {'key': key, 'value': text_postprocessed}
+#                 asr_result_list.append(item)
+#                 finish_count += 1
+#                 asr_utils.print_progress(finish_count / file_count)
+#                 if writer is not None:
+#                     ibest_writer["text"][key] = text
+#     return asr_result_list
+
+def inference(
+        maxlenratio: float,
+        minlenratio: float,
+        batch_size: int,
+        beam_size: int,
+        ngpu: int,
+        ctc_weight: float,
+        lm_weight: float,
+        penalty: float,
+        log_level: Union[int, str],
+        data_path_and_name_and_type,
+        asr_train_config: Optional[str],
+        asr_model_file: Optional[str],
+        cmvn_file: Optional[str] = None,
+        raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+        lm_train_config: Optional[str] = None,
+        lm_file: Optional[str] = None,
+        token_type: Optional[str] = None,
+        key_file: Optional[str] = None,
+        word_lm_train_config: Optional[str] = None,
+        bpemodel: Optional[str] = None,
+        allow_variable_data_keys: bool = False,
+        streaming: bool = False,
+        output_dir: Optional[str] = None,
+        dtype: str = "float32",
+        seed: int = 0,
+        ngram_weight: float = 0.9,
+        nbest: int = 1,
+        num_workers: int = 1,
+        **kwargs,
+):
+    inference_pipeline = inference_modelscope(
+        maxlenratio=maxlenratio,
+        minlenratio=minlenratio,
+        batch_size=batch_size,
+        beam_size=beam_size,
+        ngpu=ngpu,
+        ctc_weight=ctc_weight,
+        lm_weight=lm_weight,
+        penalty=penalty,
+        log_level=log_level,
+        asr_train_config=asr_train_config,
+        asr_model_file=asr_model_file,
+        cmvn_file=cmvn_file,
+        raw_inputs=raw_inputs,
+        lm_train_config=lm_train_config,
+        lm_file=lm_file,
+        token_type=token_type,
+        key_file=key_file,
+        word_lm_train_config=word_lm_train_config,
+        bpemodel=bpemodel,
+        allow_variable_data_keys=allow_variable_data_keys,
+        streaming=streaming,
+        output_dir=output_dir,
+        dtype=dtype,
+        seed=seed,
+        ngram_weight=ngram_weight,
+        nbest=nbest,
+        num_workers=num_workers,
+        **kwargs,
+    )
+    return inference_pipeline(data_path_and_name_and_type, raw_inputs)
+
+def inference_modelscope(
+    maxlenratio: float,
+    minlenratio: float,
+    batch_size: int,
+    beam_size: int,
+    ngpu: int,
+    ctc_weight: float,
+    lm_weight: float,
+    penalty: float,
+    log_level: Union[int, str],
+    # data_path_and_name_and_type,
+    asr_train_config: Optional[str],
+    asr_model_file: Optional[str],
+    cmvn_file: Optional[str] = None,
+    lm_train_config: Optional[str] = None,
+    lm_file: Optional[str] = None,
+    token_type: Optional[str] = None,
+    key_file: Optional[str] = None,
+    word_lm_train_config: Optional[str] = None,
+    bpemodel: Optional[str] = None,
+    allow_variable_data_keys: bool = False,
+    streaming: bool = False,
+    output_dir: Optional[str] = None,
+    dtype: str = "float32",
+    seed: int = 0,
+    ngram_weight: float = 0.9,
+    nbest: int = 1,
+    num_workers: int = 1,
+    param_dict: dict = None,
+    **kwargs,
+):
+    assert check_argument_types()
+    if batch_size > 1:
+        raise NotImplementedError("batch decoding is not implemented")
+    if word_lm_train_config is not None:
+        raise NotImplementedError("Word LM is not implemented")
+    if ngpu > 1:
+        raise NotImplementedError("only single GPU decoding is supported")
+    
+    logging.basicConfig(
+        level=log_level,
+        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+    )
+    
+    if ngpu >= 1 and torch.cuda.is_available():
+        device = "cuda"
+    else:
+        device = "cpu"
+    
+    # 1. Set random-seed
+    set_all_random_seed(seed)
+    
+    # 2. Build speech2text
+    speech2text_kwargs = dict(
+        asr_train_config=asr_train_config,
+        asr_model_file=asr_model_file,
+        cmvn_file=cmvn_file,
+        lm_train_config=lm_train_config,
+        lm_file=lm_file,
+        token_type=token_type,
+        bpemodel=bpemodel,
+        device=device,
+        maxlenratio=maxlenratio,
+        minlenratio=minlenratio,
+        dtype=dtype,
+        beam_size=beam_size,
+        ctc_weight=ctc_weight,
+        lm_weight=lm_weight,
+        ngram_weight=ngram_weight,
+        penalty=penalty,
+        nbest=nbest,
+        streaming=streaming,
+    )
+    logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
+    speech2text = Speech2Text(**speech2text_kwargs)
+    
+    def _forward(data_path_and_name_and_type,
+                 raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+                 output_dir_v2: Optional[str] = None,
+                 fs: dict = None,
+                 param_dict: dict = None,
+                 ):
+        # 3. Build data-iterator
+        if data_path_and_name_and_type is None and raw_inputs is not None:
+            if isinstance(raw_inputs, torch.Tensor):
+                raw_inputs = raw_inputs.numpy()
+            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
+        loader = ASRTask.build_streaming_iterator(
+            data_path_and_name_and_type,
+            dtype=dtype,
+            batch_size=batch_size,
+            key_file=key_file,
+            num_workers=num_workers,
+            preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
+            collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
+            allow_variable_data_keys=allow_variable_data_keys,
+            inference=True,
+        )
+        
+        finish_count = 0
+        file_count = 1
+        # 4. Start for-loop
+        # FIXME(kamo): The output format should be discussed
+        asr_result_list = []
+        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
+        if output_path is not None:
+            writer = DatadirWriter(output_path)
+        else:
+            writer = None
+        
+        for keys, batch in loader:
+            assert isinstance(batch, dict), type(batch)
+            assert all(isinstance(s, str) for s in keys), keys
+            _bs = len(next(iter(batch.values())))
+            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
+            # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
+            
+            # N-best list of (text, token, token_int, hyp_object)
+            try:
+                results = speech2text(**batch)
+            except TooShortUttError as e:
+                logging.warning(f"Utterance {keys} {e}")
+                hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
+                results = [[" ", ["<space>"], [2], hyp]] * nbest
+            
+            # Only supporting batch_size==1
+            key = keys[0]
+            for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
+                # Create a directory: outdir/{n}best_recog
+                if writer is not None:
+                    ibest_writer = writer[f"{n}best_recog"]
+                    
+                    # Write the result to each file
+                    ibest_writer["token"][key] = " ".join(token)
+                    # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
+                    ibest_writer["score"][key] = str(hyp.score)
+                
+                if text is not None:
+                    text_postprocessed = postprocess_utils.sentence_postprocess(token)
+                    item = {'key': key, 'value': text_postprocessed}
+                    asr_result_list.append(item)
+                    finish_count += 1
+                    asr_utils.print_progress(finish_count / file_count)
+                    if writer is not None:
+                        ibest_writer["text"][key] = text
+        return asr_result_list
+    
+    return _forward
+
+def set_parameters(language: str = None,
+                   sample_rate: Union[int, Dict[Any, int]] = None):
+    if language is not None:
+        global global_asr_language
+        global_asr_language = language
+    if sample_rate is not None:
+        global global_sample_rate
+        global_sample_rate = sample_rate
+
+
+def get_parser():
+    parser = config_argparse.ArgumentParser(
+        description="ASR Decoding",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+
+    # Note(kamo): Use '_' instead of '-' as separator.
+    # '-' is confusing if written in yaml.
+    parser.add_argument(
+        "--log_level",
+        type=lambda x: x.upper(),
+        default="INFO",
+        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
+        help="The verbose level of logging",
+    )
+
+    parser.add_argument("--output_dir", type=str, required=True)
+    parser.add_argument(
+        "--ngpu",
+        type=int,
+        default=0,
+        help="The number of gpus. 0 indicates CPU mode",
+    )
+    parser.add_argument(
+        "--gpuid_list",
+        type=str,
+        default="",
+        help="The visible gpus",
+    )
+    parser.add_argument("--seed", type=int, default=0, help="Random seed")
+    parser.add_argument(
+        "--dtype",
+        default="float32",
+        choices=["float16", "float32", "float64"],
+        help="Data type",
+    )
+    parser.add_argument(
+        "--num_workers",
+        type=int,
+        default=1,
+        help="The number of workers used for DataLoader",
+    )
+
+    group = parser.add_argument_group("Input data related")
+    group.add_argument(
+        "--data_path_and_name_and_type",
+        type=str2triple_str,
+        required=False,
+        action="append",
+    )
+    group.add_argument("--raw_inputs", type=list, default=None)
+    # example=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}])
+    group.add_argument("--key_file", type=str_or_none)
+    group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
+
+    group = parser.add_argument_group("The model configuration related")
+    group.add_argument(
+        "--asr_train_config",
+        type=str,
+        help="ASR training configuration",
+    )
+    group.add_argument(
+        "--asr_model_file",
+        type=str,
+        help="ASR model parameter file",
+    )
+    group.add_argument(
+        "--cmvn_file",
+        type=str,
+        help="Global cmvn file",
+    )
+    group.add_argument(
+        "--lm_train_config",
+        type=str,
+        help="LM training configuration",
+    )
+    group.add_argument(
+        "--lm_file",
+        type=str,
+        help="LM parameter file",
+    )
+    group.add_argument(
+        "--word_lm_train_config",
+        type=str,
+        help="Word LM training configuration",
+    )
+    group.add_argument(
+        "--word_lm_file",
+        type=str,
+        help="Word LM parameter file",
+    )
+    group.add_argument(
+        "--ngram_file",
+        type=str,
+        help="N-gram parameter file",
+    )
+    group.add_argument(
+        "--model_tag",
+        type=str,
        help="Pretrained model tag. If this option is specified, *_train_config and "
             "*_file will be overwritten",
+    )
+
+    group = parser.add_argument_group("Beam-search related")
+    group.add_argument(
+        "--batch_size",
+        type=int,
+        default=1,
+        help="The batch size for inference",
+    )
+    group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
+    group.add_argument("--beam_size", type=int, default=20, help="Beam size")
+    group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
+    group.add_argument(
+        "--maxlenratio",
+        type=float,
+        default=0.0,
        help="Input length ratio to obtain max output length. "
             "If maxlenratio=0.0 (default), it uses an end-detect "
             "function "
             "to automatically find maximum hypothesis lengths. "
             "If maxlenratio<0.0, its absolute value is interpreted "
             "as a constant max output length",
+    )
+    group.add_argument(
+        "--minlenratio",
+        type=float,
+        default=0.0,
+        help="Input length ratio to obtain min output length",
+    )
+    group.add_argument(
+        "--ctc_weight",
+        type=float,
+        default=0.5,
+        help="CTC weight in joint decoding",
+    )
+    group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
+    group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
+    group.add_argument("--streaming", type=str2bool, default=False)
+
+    group = parser.add_argument_group("Text converter related")
+    group.add_argument(
+        "--token_type",
+        type=str_or_none,
+        default=None,
+        choices=["char", "bpe", None],
        help="The token type for ASR model. "
             "If not given, it is inferred from the training args",
+    )
+    group.add_argument(
+        "--bpemodel",
+        type=str_or_none,
+        default=None,
        help="The model path of sentencepiece. "
             "If not given, it is inferred from the training args",
+    )
+
+    return parser
+
+
+def main(cmd=None):
+    print(get_commandline_args(), file=sys.stderr)
+    parser = get_parser()
+    args = parser.parse_args(cmd)
+    kwargs = vars(args)
+    kwargs.pop("config", None)
+    inference(**kwargs)
+
+
+if __name__ == "__main__":
+    main()
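`inference_modelscope` above builds the model once and returns a `_forward` closure that is then invoked per request. The build-once / call-many factory pattern in miniature (all names below are illustrative; a plain dict stands in for the real `Speech2Text` model):

```python
# Minimal sketch of the factory pattern used by inference_modelscope:
# expensive setup runs once in the factory, and the returned closure
# reuses it on every call.
def make_pipeline(model_name: str):
    model = {"name": model_name}  # stands in for Speech2Text(**speech2text_kwargs)

    def _forward(keys):
        # reuse the already-built `model` for each request
        return [{"key": k, "value": f"decoded-by-{model['name']}"} for k in keys]

    return _forward


pipeline = make_pipeline("mfcca")
results = pipeline(["utt1", "utt2"])
```

Callers can hold on to `pipeline` and feed it new inputs repeatedly without paying the model-construction cost again.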

+ 2 - 0
funasr/bin/build_trainer.py

@@ -27,6 +27,8 @@ def parse_args(mode):
         from funasr.tasks.asr import ASRTaskParaformer as ASRTask
     elif mode == "uniasr":
         from funasr.tasks.asr import ASRTaskUniASR as ASRTask
+    elif mode == "mfcca":
+        from funasr.tasks.asr import ASRTaskMFCCA as ASRTask
     else:
         raise ValueError("Unknown mode: {}".format(mode))
     parser = ASRTask.get_parser()
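The `elif` chain above grows one branch per supported mode; an equivalent table-driven sketch of the same dispatch (plain strings stand in for the imported task classes, which would require FunASR to load):

```python
# Table-driven version of the mode -> ASRTask dispatch in parse_args.
TASKS = {
    "paraformer": "ASRTaskParaformer",
    "uniasr": "ASRTaskUniASR",
    "mfcca": "ASRTaskMFCCA",
}

def resolve_task(mode: str) -> str:
    try:
        return TASKS[mode]
    except KeyError:
        # mirror the original error message for unknown modes
        raise ValueError("Unknown mode: {}".format(mode))
```

Adding a new task then becomes a one-line table entry rather than a new branch.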

+ 2 - 2
funasr/datasets/iterable_dataset.py

@@ -224,7 +224,7 @@ class IterableESPnetDataset(IterableDataset):
                 name = self.path_name_type_list[i][1]
                 _type = self.path_name_type_list[i][2]
                 if _type == "sound":
-                    audio_type = os.path.basename(value).split(".")[1].lower()
+                    audio_type = os.path.basename(value).split(".")[-1].lower()
                     if audio_type not in SUPPORT_AUDIO_TYPE_SETS:
                         raise NotImplementedError(
                             f'Not supported audio type: {audio_type}')
@@ -326,7 +326,7 @@ class IterableESPnetDataset(IterableDataset):
                 # 2.a. Load data streamingly
                 for value, (path, name, _type) in zip(values, self.path_name_type_list):
                     if _type == "sound":
-                        audio_type = os.path.basename(value).split(".")[1].lower()
+                        audio_type = os.path.basename(value).split(".")[-1].lower()
                         if audio_type not in SUPPORT_AUDIO_TYPE_SETS:
                             raise NotImplementedError(
                                 f'Not supported audio type: {audio_type}')
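The `[1]` to `[-1]` change above matters when a filename contains more than one dot: the second dot-separated field is not the extension. A quick standalone check of the two behaviors (the path is made up for illustration):

```python
import os

def ext_old(path: str) -> str:
    # old behavior: second dot-separated field, wrong for multi-dot names
    return os.path.basename(path).split(".")[1].lower()

def ext_new(path: str) -> str:
    # fixed behavior: the last field is always the real extension
    return os.path.basename(path).split(".")[-1].lower()

print(ext_old("/data/spk1.session2.wav"))  # -> "session2" (not an audio type)
print(ext_new("/data/spk1.session2.wav"))  # -> "wav"
```

With the old indexing, such files would be rejected by the `SUPPORT_AUDIO_TYPE_SETS` check even though they are ordinary wav files.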

+ 9 - 19
funasr/export/README.md

@@ -11,33 +11,23 @@ The installation is the same as [funasr](../../README.md)
 
 ## Export onnx format model
 Export model from modelscope
-```python
-from funasr.export.export_model import ASRModelExportParaformer
-
-output_dir = "../export"  # onnx/torchscripts model save path
-export_model = ASRModelExportParaformer(cache_dir=output_dir, onnx=True)
-export_model.export('damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
+```shell
+python -m funasr.export.export_model 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' "./export" true
 ```
-
-
 Export model from local path
-```python
-export_model.export('/root/cache/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
+```shell
+python -m funasr.export.export_model '/mnt/workspace/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' "./export" true
 ```
 
 ## Export torchscripts format model
 Export model from modelscope
-```python
-from funasr.export.export_model import ASRModelExportParaformer
-
-output_dir = "../export"  # onnx/torchscripts model save path
-export_model = ASRModelExportParaformer(cache_dir=output_dir, onnx=False)
-export_model.export('damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
+```shell
+python -m funasr.export.export_model 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' "./export" false
 ```
 
-Export model from local path
-```python
 
-export_model.export('/root/cache/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
+Export model from local path
+```shell
+python -m funasr.export.export_model '/mnt/workspace/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' "./export" false
 ```
 

+ 13 - 5
funasr/export/export_model.py

@@ -24,7 +24,7 @@ class ASRModelExportParaformer:
             feats_dim=560,
             onnx=False,
         )
-        logging.info("output dir: {}".format(self.cache_dir))
+        print("output dir: {}".format(self.cache_dir))
         self.onnx = onnx
         
 
@@ -50,7 +50,7 @@ class ASRModelExportParaformer:
         else:
             self._export_torchscripts(model, verbose, export_dir)
 
-        logging.info("output dir: {}".format(export_dir))
+        print("output dir: {}".format(export_dir))
 
 
     def _export_torchscripts(self, model, verbose, path, enc_size=None):
@@ -117,7 +117,15 @@ class ASRModelExportParaformer:
         )
 
 if __name__ == '__main__':
-    output_dir = "../export"
-    export_model = ASRModelExportParaformer(cache_dir=output_dir, onnx=True)
-    export_model.export('damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
+    import sys
+
+    model_path = sys.argv[1]
+    output_dir = sys.argv[2]
+    onnx = sys.argv[3].lower() == 'true'
+    # model_path = 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'
+    # output_dir = "../export"
+    export_model = ASRModelExportParaformer(cache_dir=output_dir, onnx=onnx)
+    export_model.export(model_path)
     # export_model.export('/root/cache/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
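The new `__main__` block parses its third argument by lowercasing it and comparing against `'true'`, so any unrecognized value silently becomes `False`. If stricter handling were wanted, a standalone parser could look like this (hypothetical helper, not part of the commit):

```python
def str2bool(value: str) -> bool:
    # Strict flag parser: unrecognized values raise instead of silently
    # mapping to False as a bare `== 'true'` comparison would.
    lowered = value.strip().lower()
    if lowered in ("true", "1", "yes"):
        return True
    if lowered in ("false", "0", "no"):
        return False
    raise ValueError(f"expected a boolean flag, got {value!r}")
```

Used as `onnx = str2bool(sys.argv[3])`, a typo such as `ture` would then fail loudly instead of quietly exporting TorchScript.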

+ 1 - 1
funasr/export/models/e2e_asr_paraformer.py

@@ -59,7 +59,7 @@ class Paraformer(nn.Module):
         enc, enc_len = self.encoder(**batch)
         mask = self.make_pad_mask(enc_len)[:, None, :]
         pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(enc, mask)
-        pre_token_length = pre_token_length.round().long()
+        pre_token_length = pre_token_length.round().type(torch.int32)
 
         decoder_out, _ = self.decoder(enc, enc_len, pre_acoustic_embeds, pre_token_length)
         decoder_out = torch.log_softmax(decoder_out, dim=-1)

+ 0 - 50
funasr/export/models/predictor/cif.py

@@ -116,53 +116,3 @@ def cif(hidden, alphas, threshold: float):
 		pad_l = torch.zeros([int(max_label_len - l.size(0)), int(hidden_size)], device=hidden.device)
 		list_ls.append(torch.cat([l, pad_l], 0))
 	return torch.stack(list_ls, 0), fires
-
-
-def CifPredictorV2_test():
-	x = torch.rand([2, 21, 2])
-	x_len = torch.IntTensor([6, 21])
-	
-	mask = sequence_mask(x_len, maxlen=x.size(1), dtype=x.dtype)
-	x = x * mask[:, :, None]
-	
-	predictor_scripts = torch.jit.script(CifPredictorV2(2, 1, 1))
-	# cif_output, cif_length, alphas, cif_peak = predictor_scripts(x, mask=mask[:, None, :])
-	predictor_scripts.save('test.pt')
-	loaded = torch.jit.load('test.pt')
-	cif_output, cif_length, alphas, cif_peak = loaded(x, mask=mask[:, None, :])
-	# print(cif_output)
-	print(predictor_scripts.code)
-	# predictor = CifPredictorV2(2, 1, 1)
-	# cif_output, cif_length, alphas, cif_peak = predictor(x, mask=mask[:, None, :])
-	print(cif_output)
-
-
-def CifPredictorV2_export_test():
-	x = torch.rand([2, 21, 2])
-	x_len = torch.IntTensor([6, 21])
-	
-	mask = sequence_mask(x_len, maxlen=x.size(1), dtype=x.dtype)
-	x = x * mask[:, :, None]
-	
-	# predictor_scripts = torch.jit.script(CifPredictorV2(2, 1, 1))
-	# cif_output, cif_length, alphas, cif_peak = predictor_scripts(x, mask=mask[:, None, :])
-	predictor = CifPredictorV2(2, 1, 1)
-	predictor_trace = torch.jit.trace(predictor, (x, mask[:, None, :]))
-	predictor_trace.save('test_trace.pt')
-	loaded = torch.jit.load('test_trace.pt')
-	
-	x = torch.rand([3, 30, 2])
-	x_len = torch.IntTensor([6, 20, 30])
-	mask = sequence_mask(x_len, maxlen=x.size(1), dtype=x.dtype)
-	x = x * mask[:, :, None]
-	cif_output, cif_length, alphas, cif_peak = loaded(x, mask=mask[:, None, :])
-	print(cif_output)
-	# print(predictor_trace.code)
-	# predictor = CifPredictorV2(2, 1, 1)
-	# cif_output, cif_length, alphas, cif_peak = predictor(x, mask=mask[:, None, :])
-	# print(cif_output)
-
-
-if __name__ == '__main__':
-	# CifPredictorV2_test()
-	CifPredictorV2_export_test()

+ 322 - 0
funasr/models/e2e_asr_mfcca.py

@@ -0,0 +1,322 @@
+from contextlib import contextmanager
+from distutils.version import LooseVersion
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+import logging
+import torch
+from typeguard import check_argument_types
+
+from funasr.modules.e2e_asr_common import ErrorCalculator
+from funasr.modules.nets_utils import th_accuracy
+from funasr.modules.add_sos_eos import add_sos_eos
+from funasr.losses.label_smoothing_loss import (
+    LabelSmoothingLoss,  # noqa: H301
+)
+from funasr.models.ctc import CTC
+from funasr.models.decoder.abs_decoder import AbsDecoder
+from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
+from funasr.models.specaug.abs_specaug import AbsSpecAug
+from funasr.layers.abs_normalize import AbsNormalize
+from funasr.torch_utils.device_funcs import force_gatherable
+from funasr.train.abs_espnet_model import AbsESPnetModel
+
+if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
+    from torch.cuda.amp import autocast
+else:
+    # Nothing to do if torch<1.6.0
+    @contextmanager
+    def autocast(enabled=True):
+        yield
+import random
+import math
+
+
+class MFCCA(AbsESPnetModel):
+    """CTC-attention hybrid Encoder-Decoder model"""
+
+    def __init__(
+        self,
+        vocab_size: int,
+        token_list: Union[Tuple[str, ...], List[str]],
+        frontend: Optional[AbsFrontend],
+        specaug: Optional[AbsSpecAug],
+        normalize: Optional[AbsNormalize],
+        preencoder: Optional[AbsPreEncoder],
+        encoder: AbsEncoder,
+        decoder: AbsDecoder,
+        ctc: CTC,
+        rnnt_decoder: None,
+        ctc_weight: float = 0.5,
+        ignore_id: int = -1,
+        lsm_weight: float = 0.0,
+        mask_ratio: float = 0.0,
+        length_normalized_loss: bool = False,
+        report_cer: bool = True,
+        report_wer: bool = True,
+        sym_space: str = "<space>",
+        sym_blank: str = "<blank>",
+    ):
+        assert check_argument_types()
+        assert 0.0 <= ctc_weight <= 1.0, ctc_weight
+        assert rnnt_decoder is None, "Not implemented"
+
+        super().__init__()
+        # note that eos is the same as sos (equivalent ID)
+        self.sos = vocab_size - 1
+        self.eos = vocab_size - 1
+        self.vocab_size = vocab_size
+        self.ignore_id = ignore_id
+        self.ctc_weight = ctc_weight
+        self.token_list = token_list.copy()
+        
+
+        self.mask_ratio = mask_ratio
+
+        self.specaug = specaug
+        self.normalize = normalize
+        self.preencoder = preencoder
+        self.encoder = encoder
+        # we set self.decoder = None in the CTC mode since
+        # self.decoder parameters were never used and PyTorch complained
+        # and threw an Exception in the multi-GPU experiment.
+        # thanks Jeff Farris for pointing out the issue.
+        if ctc_weight == 1.0:
+            self.decoder = None
+        else:
+            self.decoder = decoder
+        if ctc_weight == 0.0:
+            self.ctc = None
+        else:
+            self.ctc = ctc
+        self.rnnt_decoder = rnnt_decoder
+        self.criterion_att = LabelSmoothingLoss(
+            size=vocab_size,
+            padding_idx=ignore_id,
+            smoothing=lsm_weight,
+            normalize_length=length_normalized_loss,
+        )
+
+        if report_cer or report_wer:
+            self.error_calculator = ErrorCalculator(
+                token_list, sym_space, sym_blank, report_cer, report_wer
+            )
+        else:
+            self.error_calculator = None
+
+    def forward(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
+    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+        """Frontend + Encoder + Decoder + Calc loss
+
+        Args:
+            speech: (Batch, Length, ...)
+            speech_lengths: (Batch, )
+            text: (Batch, Length)
+            text_lengths: (Batch,)
+        """
+        assert text_lengths.dim() == 1, text_lengths.shape
+        # Check that batch_size is unified
+        assert (
+            speech.shape[0]
+            == speech_lengths.shape[0]
+            == text.shape[0]
+            == text_lengths.shape[0]
+        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
+        # Channel-dropout augmentation: with probability mask_ratio, keep only
+        # a random subset of the 8 input channels.
+        if speech.dim() == 3 and speech.size(2) == 8 and self.mask_ratio != 0:
+            rate_num = random.random()
+            if rate_num <= self.mask_ratio:
+                retain_channel = math.ceil(random.random() * 8)
+                if retain_channel > 1:
+                    speech = speech[:, :, torch.randperm(8)[0:retain_channel].sort().values]
+                else:
+                    speech = speech[:, :, torch.randperm(8)[0]]
+        batch_size = speech.shape[0]
+        # for data-parallel
+        text = text[:, : text_lengths.max()]
+
+        # 1. Encoder
+        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+
+        # 2a. Attention-decoder branch
+        if self.ctc_weight == 1.0:
+            loss_att, acc_att, cer_att, wer_att = None, None, None, None
+        else:
+            loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
+                encoder_out, encoder_out_lens, text, text_lengths
+            )
+
+        # 2b. CTC branch
+        if self.ctc_weight == 0.0:
+            loss_ctc, cer_ctc = None, None
+        else:
+            loss_ctc, cer_ctc = self._calc_ctc_loss(
+                encoder_out, encoder_out_lens, text, text_lengths
+            )
+
+        # 2c. RNN-T branch
+        if self.rnnt_decoder is not None:
+            _ = self._calc_rnnt_loss(encoder_out, encoder_out_lens, text, text_lengths)
+
+        if self.ctc_weight == 0.0:
+            loss = loss_att
+        elif self.ctc_weight == 1.0:
+            loss = loss_ctc
+        else:
+            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
+
+        stats = dict(
+            loss=loss.detach(),
+            loss_att=loss_att.detach() if loss_att is not None else None,
+            loss_ctc=loss_ctc.detach() if loss_ctc is not None else None,
+            acc=acc_att,
+            cer=cer_att,
+            wer=wer_att,
+            cer_ctc=cer_ctc,
+        )
+
+        # force_gatherable: to-device and to-tensor if scalar for DataParallel
+        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+        return loss, stats, weight
+
+    def collect_feats(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
+    ) -> Dict[str, torch.Tensor]:
+        feats, feats_lengths, channel_size = self._extract_feats(speech, speech_lengths)
+        return {"feats": feats, "feats_lengths": feats_lengths}
+
+    def encode(
+        self, speech: torch.Tensor, speech_lengths: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Frontend + Encoder. Note that this method is used by asr_inference.py
+
+        Args:
+            speech: (Batch, Length, ...)
+            speech_lengths: (Batch, )
+        """
+        with autocast(False):
+            # 1. Extract feats
+            feats, feats_lengths, channel_size = self._extract_feats(speech, speech_lengths)
+            # 2. Data augmentation
+            if self.specaug is not None and self.training:
+                feats, feats_lengths = self.specaug(feats, feats_lengths)
+
+            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+            if self.normalize is not None:
+                feats, feats_lengths = self.normalize(feats, feats_lengths)
+
+        # Pre-encoder, e.g. used for raw input data
+        if self.preencoder is not None:
+            feats, feats_lengths = self.preencoder(feats, feats_lengths)
+        encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths, channel_size)
+
+        assert encoder_out.size(0) == speech.size(0), (
+            encoder_out.size(),
+            speech.size(0),
+        )
+        if encoder_out.dim() == 4:
+            assert encoder_out.size(2) <= encoder_out_lens.max(), (
+                encoder_out.size(),
+                encoder_out_lens.max(),
+            )
+        else:
+            assert encoder_out.size(1) <= encoder_out_lens.max(), (
+                encoder_out.size(),
+                encoder_out_lens.max(),
+            )
+
+        return encoder_out, encoder_out_lens
+
+    def _extract_feats(
+        self, speech: torch.Tensor, speech_lengths: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        assert speech_lengths.dim() == 1, speech_lengths.shape
+        # for data-parallel
+        speech = speech[:, : speech_lengths.max()]
+        if self.frontend is not None:
+            # Frontend
+            #  e.g. STFT and Feature extract
+            #       data_loader may send time-domain signal in this case
+            # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
+            feats, feats_lengths, channel_size = self.frontend(speech, speech_lengths)
+        else:
+            # No frontend and no feature extract
+            feats, feats_lengths = speech, speech_lengths
+            channel_size = 1
+        return feats, feats_lengths, channel_size
+
+    def _calc_att_loss(
+        self,
+        encoder_out: torch.Tensor,
+        encoder_out_lens: torch.Tensor,
+        ys_pad: torch.Tensor,
+        ys_pad_lens: torch.Tensor,
+    ):
+        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
+        ys_in_lens = ys_pad_lens + 1
+
+        # 1. Forward decoder
+        decoder_out, _ = self.decoder(
+            encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
+        )
+
+        # 2. Compute attention loss
+        loss_att = self.criterion_att(decoder_out, ys_out_pad)
+        acc_att = th_accuracy(
+            decoder_out.view(-1, self.vocab_size),
+            ys_out_pad,
+            ignore_label=self.ignore_id,
+        )
+
+        # Compute cer/wer using attention-decoder
+        if self.training or self.error_calculator is None:
+            cer_att, wer_att = None, None
+        else:
+            ys_hat = decoder_out.argmax(dim=-1)
+            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
+
+        return loss_att, acc_att, cer_att, wer_att
+
+    def _calc_ctc_loss(
+        self,
+        encoder_out: torch.Tensor,
+        encoder_out_lens: torch.Tensor,
+        ys_pad: torch.Tensor,
+        ys_pad_lens: torch.Tensor,
+    ):
+        # Calc CTC loss (average the channel dimension first, if present)
+        if encoder_out.dim() == 4:
+            encoder_out = encoder_out.mean(1)
+        loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
+
+        # Calc CER using CTC
+        cer_ctc = None
+        if not self.training and self.error_calculator is not None:
+            ys_hat = self.ctc.argmax(encoder_out).data
+            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
+        return loss_ctc, cer_ctc
+
+    def _calc_rnnt_loss(
+        self,
+        encoder_out: torch.Tensor,
+        encoder_out_lens: torch.Tensor,
+        ys_pad: torch.Tensor,
+        ys_pad_lens: torch.Tensor,
+    ):
+        raise NotImplementedError
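The 4-D branch in `_calc_ctc_loss` above collapses the channel dimension by plain averaging before the CTC projection. A minimal sketch with toy sizes (the shapes here are assumptions, not taken from the model):

```python
import torch

# Hypothetical multi-channel encoder output: (batch, channel, time, dim)
encoder_out = torch.randn(2, 8, 50, 256)

# As in _calc_ctc_loss: average over the channel axis when the output is 4-D
if encoder_out.dim() == 4:
    encoder_out = encoder_out.mean(1)

print(tuple(encoder_out.shape))  # (2, 50, 256)
```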

+ 270 - 0
funasr/models/encoder/encoder_layer_mfcca.py

@@ -0,0 +1,270 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+# Copyright 2020 Johns Hopkins University (Shinji Watanabe)
+#                Northwestern Polytechnical University (Pengcheng Guo)
+#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""Encoder self-attention layer definition."""
+
+import torch
+
+from torch import nn
+
+from funasr.modules.layer_norm import LayerNorm
+from torch.autograd import Variable
+
+
+
+class Encoder_Conformer_Layer(nn.Module):
+    """Encoder layer module.
+
+    Args:
+        size (int): Input dimension.
+        self_attn (torch.nn.Module): Self-attention module instance.
+            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
+            can be used as the argument.
+        feed_forward (torch.nn.Module): Feed-forward module instance.
+            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
+            can be used as the argument.
+        feed_forward_macaron (torch.nn.Module): Additional feed-forward module instance.
+            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
+            can be used as the argument.
+        conv_module (torch.nn.Module): Convolution module instance.
+            `ConvolutionModule` instance can be used as the argument.
+        dropout_rate (float): Dropout rate.
+        normalize_before (bool): Whether to use layer_norm before the first block.
+        concat_after (bool): Whether to concat attention layer's input and output.
+            if True, additional linear will be applied.
+            i.e. x -> x + linear(concat(x, att(x)))
+            if False, no additional linear will be applied. i.e. x -> x + att(x)
+
+    """
+
+    def __init__(
+        self,
+        size,
+        self_attn,
+        feed_forward,
+        feed_forward_macaron,
+        conv_module,
+        dropout_rate,
+        normalize_before=True,
+        concat_after=False,
+        cca_pos=0,
+    ):
+        """Construct an Encoder_Conformer_Layer object."""
+        super(Encoder_Conformer_Layer, self).__init__()
+        self.self_attn = self_attn
+        self.feed_forward = feed_forward
+        self.feed_forward_macaron = feed_forward_macaron
+        self.conv_module = conv_module
+        self.norm_ff = LayerNorm(size)  # for the FNN module
+        self.norm_mha = LayerNorm(size)  # for the MHA module
+        if feed_forward_macaron is not None:
+            self.norm_ff_macaron = LayerNorm(size)
+            self.ff_scale = 0.5
+        else:
+            self.ff_scale = 1.0
+        if self.conv_module is not None:
+            self.norm_conv = LayerNorm(size)  # for the CNN module
+            self.norm_final = LayerNorm(size)  # for the final output of the block
+        self.dropout = nn.Dropout(dropout_rate)
+        self.size = size
+        self.normalize_before = normalize_before
+        self.concat_after = concat_after
+        self.cca_pos = cca_pos
+
+        if self.concat_after:
+            self.concat_linear = nn.Linear(size + size, size)
+
+    def forward(self, x_input, mask, cache=None):
+        """Compute encoded features.
+
+        Args:
+            x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb.
+                - w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)].
+                - w/o pos emb: Tensor (#batch, time, size).
+            mask (torch.Tensor): Mask tensor for the input (#batch, time).
+            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time, size).
+            torch.Tensor: Mask tensor (#batch, time).
+
+        """
+        if isinstance(x_input, tuple):
+            x, pos_emb = x_input[0], x_input[1]
+        else:
+            x, pos_emb = x_input, None
+        # whether to use macaron style
+        if self.feed_forward_macaron is not None:
+            residual = x
+            if self.normalize_before:
+                x = self.norm_ff_macaron(x)
+            x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x))
+            if not self.normalize_before:
+                x = self.norm_ff_macaron(x)
+
+        # multi-headed self-attention module
+        residual = x
+        if self.normalize_before:
+            x = self.norm_mha(x)
+
+        if cache is None:
+            x_q = x
+        else:
+            assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
+            x_q = x[:, -1:, :]
+            residual = residual[:, -1:, :]
+            mask = None if mask is None else mask[:, -1:, :]
+
+        if self.cca_pos < 2:
+            if pos_emb is not None:
+                x_att = self.self_attn(x_q, x, x, pos_emb, mask)
+            else:
+                x_att = self.self_attn(x_q, x, x, mask)
+        else:
+            x_att = self.self_attn(x_q, x, x, mask)
+
+        if self.concat_after:
+            x_concat = torch.cat((x, x_att), dim=-1)
+            x = residual + self.concat_linear(x_concat)
+        else:
+            x = residual + self.dropout(x_att)
+        if not self.normalize_before:
+            x = self.norm_mha(x)
+
+        # convolution module
+        if self.conv_module is not None:
+            residual = x
+            if self.normalize_before:
+                x = self.norm_conv(x)
+            x = residual + self.dropout(self.conv_module(x))
+            if not self.normalize_before:
+                x = self.norm_conv(x)
+
+        # feed forward module
+        residual = x
+        if self.normalize_before:
+            x = self.norm_ff(x)
+        x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
+        if not self.normalize_before:
+            x = self.norm_ff(x)
+
+        if self.conv_module is not None:
+            x = self.norm_final(x)
+
+        if cache is not None:
+            x = torch.cat([cache, x], dim=1)
+
+        if pos_emb is not None:
+            return (x, pos_emb), mask
+
+        return x, mask
+
+
+class EncoderLayer(nn.Module):
+    """Encoder layer module.
+
+    Args:
+        size (int): Input dimension.
+        self_attn (torch.nn.Module): Self-attention module instance.
+            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
+            can be used as the argument.
+        feed_forward (torch.nn.Module): Feed-forward module instance.
+            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
+            can be used as the argument.
+        feed_forward_macaron (torch.nn.Module): Additional feed-forward module instance.
+            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
+            can be used as the argument.
+        conv_module (torch.nn.Module): Convolution module instance.
+            `ConvolutionModule` instance can be used as the argument.
+        dropout_rate (float): Dropout rate.
+        normalize_before (bool): Whether to use layer_norm before the first block.
+        concat_after (bool): Whether to concat attention layer's input and output.
+            if True, additional linear will be applied.
+            i.e. x -> x + linear(concat(x, att(x)))
+            if False, no additional linear will be applied. i.e. x -> x + att(x)
+
+    """
+
+    def __init__(
+        self,
+        size,
+        self_attn_cros_channel,
+        self_attn_conformer,
+        feed_forward_csa,
+        feed_forward_macaron_csa,
+        conv_module_csa,
+        dropout_rate,
+        normalize_before=True,
+        concat_after=False,
+    ):
+        """Construct an EncoderLayer object."""
+        super(EncoderLayer, self).__init__()
+
+        self.encoder_cros_channel_atten = self_attn_cros_channel
+        self.encoder_csa = Encoder_Conformer_Layer(
+                size,
+                self_attn_conformer,
+                feed_forward_csa,
+                feed_forward_macaron_csa,
+                conv_module_csa,
+                dropout_rate,
+                normalize_before,
+                concat_after,
+                cca_pos=0)
+        self.norm_mha = LayerNorm(size)  # for the MHA module
+        self.dropout = nn.Dropout(dropout_rate)
+
+    def forward(self, x_input, mask, channel_size, cache=None):
+        """Compute encoded features.
+
+        Args:
+            x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb.
+                - w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)].
+                - w/o pos emb: Tensor (#batch, time, size).
+            mask (torch.Tensor): Mask tensor for the input (#batch, time).
+            channel_size (int): Number of input channels per utterance.
+            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time, size).
+            torch.Tensor: Mask tensor (#batch, time).
+
+        """
+        if isinstance(x_input, tuple):
+            x, pos_emb = x_input[0], x_input[1]
+        else:
+            x, pos_emb = x_input, None
+        residual = x
+        x = self.norm_mha(x)
+        t_leng = x.size(1)
+        d_dim = x.size(2)
+        # gather the channels of each frame: (B*C, T, D) -> (B, T, C, D)
+        x_new = x.reshape(-1, channel_size, t_leng, d_dim).transpose(1, 2)
+        # keys/values: the C channels over a 5-frame window (t-2 .. t+2)
+        x_k_v = x_new.new_empty(x_new.size(0), x_new.size(1), 5, x_new.size(2), x_new.size(3))
+        pad_before = x_new.new_zeros(x_new.size(0), 2, x_new.size(2), x_new.size(3))
+        pad_after = x_new.new_zeros(x_new.size(0), 2, x_new.size(2), x_new.size(3))
+        x_pad = torch.cat([pad_before, x_new, pad_after], 1)
+        x_k_v[:, :, 0, :, :] = x_pad[:, 0:-4, :, :]
+        x_k_v[:, :, 1, :, :] = x_pad[:, 1:-3, :, :]
+        x_k_v[:, :, 2, :, :] = x_pad[:, 2:-2, :, :]
+        x_k_v[:, :, 3, :, :] = x_pad[:, 3:-1, :, :]
+        x_k_v[:, :, 4, :, :] = x_pad[:, 4:, :, :]
+        x_new = x_new.reshape(-1, channel_size, d_dim)
+        x_k_v = x_k_v.reshape(-1, 5 * channel_size, d_dim)
+        x_att = self.encoder_cros_channel_atten(x_new, x_k_v, x_k_v, None)
+        x_att = x_att.reshape(-1, t_leng, channel_size, d_dim).transpose(1, 2).reshape(-1, t_leng, d_dim)
+        x = residual + self.dropout(x_att)
+        if pos_emb is not None:
+            x_input = (x, pos_emb)
+        else:
+            x_input = x
+        x_input, mask = self.encoder_csa(x_input, mask)
+        return x_input, mask, channel_size
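The windowed key/value construction in `EncoderLayer.forward` above can be reproduced compactly with padding and stacking. A toy sketch (the sizes and the `torch.nn.functional.pad`/`torch.stack` formulation are assumptions for illustration, not the model's code):

```python
import torch
import torch.nn.functional as F

# Toy sizes (assumed): batch, time, channels, feature dim
B, T, C, D = 1, 6, 2, 3
x = torch.randn(B, T, C, D)  # channels gathered per frame, as in the layer

# Zero-pad two frames on each side of the time axis
x_pad = F.pad(x, (0, 0, 0, 0, 2, 2))          # (B, T + 4, C, D)

# Stack the five temporal shifts t-2 .. t+2: (B, T, 5, C, D)
x_k_v = torch.stack([x_pad[:, k:k + T] for k in range(5)], dim=2)

# Queries: the C channel vectors of one frame; keys/values: 5*C vectors
q = x.reshape(B * T, C, D)                    # (B*T, C, D)
kv = x_k_v.reshape(B * T, 5 * C, D)           # (B*T, 5*C, D)
print(q.shape, kv.shape)
```

The first two stacked shifts of frame 0 are all zeros, since they fall into the zero padding before the utterance start.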

+ 450 - 0
funasr/models/encoder/mfcca_encoder.py

@@ -0,0 +1,450 @@
+from typing import Optional
+from typing import Tuple
+
+import logging
+import torch
+from torch import nn
+
+from typeguard import check_argument_types
+
+from funasr.models.encoder.encoder_layer_mfcca import EncoderLayer
+from funasr.modules.nets_utils import get_activation
+from funasr.modules.nets_utils import make_pad_mask
+from funasr.modules.attention import (
+    MultiHeadedAttention,  # noqa: H301
+    RelPositionMultiHeadedAttention,  # noqa: H301
+    LegacyRelPositionMultiHeadedAttention,  # noqa: H301
+)
+from funasr.modules.embedding import (
+    PositionalEncoding,  # noqa: H301
+    ScaledPositionalEncoding,  # noqa: H301
+    RelPositionalEncoding,  # noqa: H301
+    LegacyRelPositionalEncoding,  # noqa: H301
+)
+from funasr.modules.layer_norm import LayerNorm
+from funasr.modules.multi_layer_conv import Conv1dLinear
+from funasr.modules.multi_layer_conv import MultiLayeredConv1d
+from funasr.modules.positionwise_feed_forward import (
+    PositionwiseFeedForward,  # noqa: H301
+)
+from funasr.modules.repeat import repeat
+from funasr.modules.subsampling import Conv2dSubsampling
+from funasr.modules.subsampling import Conv2dSubsampling2
+from funasr.modules.subsampling import Conv2dSubsampling6
+from funasr.modules.subsampling import Conv2dSubsampling8
+from funasr.modules.subsampling import TooShortUttError
+from funasr.modules.subsampling import check_short_utt
+from funasr.models.encoder.abs_encoder import AbsEncoder
+import math
+
+class ConvolutionModule(nn.Module):
+    """ConvolutionModule in Conformer model.
+
+    Args:
+        channels (int): The number of channels of conv layers.
+        kernel_size (int): Kernel size of conv layers.
+
+    """
+
+    def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True):
+        """Construct an ConvolutionModule object."""
+        super(ConvolutionModule, self).__init__()
+        # kernel_size should be an odd number for 'SAME' padding
+        assert (kernel_size - 1) % 2 == 0
+
+        self.pointwise_conv1 = nn.Conv1d(
+            channels,
+            2 * channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=bias,
+        )
+        self.depthwise_conv = nn.Conv1d(
+            channels,
+            channels,
+            kernel_size,
+            stride=1,
+            padding=(kernel_size - 1) // 2,
+            groups=channels,
+            bias=bias,
+        )
+        self.norm = nn.BatchNorm1d(channels)
+        self.pointwise_conv2 = nn.Conv1d(
+            channels,
+            channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=bias,
+        )
+        self.activation = activation
+
+    def forward(self, x):
+        """Compute convolution module.
+
+        Args:
+            x (torch.Tensor): Input tensor (#batch, time, channels).
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time, channels).
+
+        """
+        # exchange the temporal dimension and the feature dimension
+        x = x.transpose(1, 2)
+
+        # GLU mechanism
+        x = self.pointwise_conv1(x)  # (batch, 2*channel, time)
+        x = nn.functional.glu(x, dim=1)  # (batch, channel, time)
+
+        # 1D Depthwise Conv
+        x = self.depthwise_conv(x)
+        x = self.activation(self.norm(x))
+
+        x = self.pointwise_conv2(x)
+
+        return x.transpose(1, 2)
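The pointwise-conv + GLU step inside `ConvolutionModule.forward` doubles the channels with a 1x1 conv and then gates one half with the sigmoid of the other, returning to the original channel count. A minimal sketch with toy sizes (sizes assumed, not from the model):

```python
import torch
import torch.nn as nn

channels, T = 4, 10                               # toy sizes (assumed)
x = torch.randn(1, T, channels).transpose(1, 2)   # (batch, channels, time)

# 1x1 conv doubles the channels, GLU halves them back via gating
pw = nn.Conv1d(channels, 2 * channels, kernel_size=1)
y = nn.functional.glu(pw(x), dim=1)               # (batch, channels, time)
print(tuple(y.shape))  # (1, 4, 10)
```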
+
+
+
+class MFCCAEncoder(AbsEncoder):
+    """MFCCA encoder module (Conformer blocks with cross-channel attention).
+
+    Args:
+        input_size (int): Input dimension.
+        output_size (int): Dimension of attention.
+        attention_heads (int): The number of heads of multi head attention.
+        linear_units (int): The number of units of position-wise feed forward.
+        num_blocks (int): The number of decoder blocks.
+        dropout_rate (float): Dropout rate.
+        attention_dropout_rate (float): Dropout rate in attention.
+        positional_dropout_rate (float): Dropout rate after adding positional encoding.
+        input_layer (Union[str, torch.nn.Module]): Input layer type.
+        normalize_before (bool): Whether to use layer_norm before the first block.
+        concat_after (bool): Whether to concat attention layer's input and output.
+            If True, additional linear will be applied.
+            i.e. x -> x + linear(concat(x, att(x)))
+            If False, no additional linear will be applied. i.e. x -> x + att(x)
+        positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear".
+        positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer.
+        rel_pos_type (str): Whether to use the latest relative positional encoding or
+            the legacy one. The legacy relative positional encoding will be deprecated
+            in the future. More Details can be found in
+            https://github.com/espnet/espnet/pull/2816.
+        encoder_pos_enc_layer_type (str): Encoder positional encoding layer type.
+        encoder_attn_layer_type (str): Encoder attention layer type.
+        activation_type (str): Encoder activation function type.
+        macaron_style (bool): Whether to use macaron style for positionwise layer.
+        use_cnn_module (bool): Whether to use convolution module.
+        zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
+        cnn_module_kernel (int): Kernel size of convolution module.
+        padding_idx (int): Padding idx for input_layer=embed.
+
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int = 256,
+        attention_heads: int = 4,
+        linear_units: int = 2048,
+        num_blocks: int = 6,
+        dropout_rate: float = 0.1,
+        positional_dropout_rate: float = 0.1,
+        attention_dropout_rate: float = 0.0,
+        input_layer: str = "conv2d",
+        normalize_before: bool = True,
+        concat_after: bool = False,
+        positionwise_layer_type: str = "linear",
+        positionwise_conv_kernel_size: int = 3,
+        macaron_style: bool = False,
+        rel_pos_type: str = "legacy",
+        pos_enc_layer_type: str = "rel_pos",
+        selfattention_layer_type: str = "rel_selfattn",
+        activation_type: str = "swish",
+        use_cnn_module: bool = True,
+        zero_triu: bool = False,
+        cnn_module_kernel: int = 31,
+        padding_idx: int = -1,
+    ):
+        assert check_argument_types()
+        super().__init__()
+        self._output_size = output_size
+
+        if rel_pos_type == "legacy":
+            if pos_enc_layer_type == "rel_pos":
+                pos_enc_layer_type = "legacy_rel_pos"
+            if selfattention_layer_type == "rel_selfattn":
+                selfattention_layer_type = "legacy_rel_selfattn"
+        elif rel_pos_type == "latest":
+            assert selfattention_layer_type != "legacy_rel_selfattn"
+            assert pos_enc_layer_type != "legacy_rel_pos"
+        else:
+            raise ValueError("unknown rel_pos_type: " + rel_pos_type)
+
+        activation = get_activation(activation_type)
+        if pos_enc_layer_type == "abs_pos":
+            pos_enc_class = PositionalEncoding
+        elif pos_enc_layer_type == "scaled_abs_pos":
+            pos_enc_class = ScaledPositionalEncoding
+        elif pos_enc_layer_type == "rel_pos":
+            assert selfattention_layer_type == "rel_selfattn"
+            pos_enc_class = RelPositionalEncoding
+        elif pos_enc_layer_type == "legacy_rel_pos":
+            assert selfattention_layer_type == "legacy_rel_selfattn"
+            pos_enc_class = LegacyRelPositionalEncoding
+            logging.warning(
+                "Using legacy_rel_pos and it will be deprecated in the future."
+            )
+        else:
+            raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
+
+        if input_layer == "linear":
+            self.embed = torch.nn.Sequential(
+                torch.nn.Linear(input_size, output_size),
+                torch.nn.LayerNorm(output_size),
+                torch.nn.Dropout(dropout_rate),
+                pos_enc_class(output_size, positional_dropout_rate),
+            )
+        elif input_layer == "conv2d":
+            self.embed = Conv2dSubsampling(
+                input_size,
+                output_size,
+                dropout_rate,
+                pos_enc_class(output_size, positional_dropout_rate),
+            )
+        elif input_layer == "conv2d6":
+            self.embed = Conv2dSubsampling6(
+                input_size,
+                output_size,
+                dropout_rate,
+                pos_enc_class(output_size, positional_dropout_rate),
+            )
+        elif input_layer == "conv2d8":
+            self.embed = Conv2dSubsampling8(
+                input_size,
+                output_size,
+                dropout_rate,
+                pos_enc_class(output_size, positional_dropout_rate),
+            )
+        elif input_layer == "embed":
+            self.embed = torch.nn.Sequential(
+                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
+                pos_enc_class(output_size, positional_dropout_rate),
+            )
+        elif isinstance(input_layer, torch.nn.Module):
+            self.embed = torch.nn.Sequential(
+                input_layer,
+                pos_enc_class(output_size, positional_dropout_rate),
+            )
+        elif input_layer is None:
+            self.embed = torch.nn.Sequential(
+                pos_enc_class(output_size, positional_dropout_rate)
+            )
+        else:
+            raise ValueError("unknown input_layer: " + input_layer)
+        self.normalize_before = normalize_before
+        if positionwise_layer_type == "linear":
+            positionwise_layer = PositionwiseFeedForward
+            positionwise_layer_args = (
+                output_size,
+                linear_units,
+                dropout_rate,
+                activation,
+            )
+        elif positionwise_layer_type == "conv1d":
+            positionwise_layer = MultiLayeredConv1d
+            positionwise_layer_args = (
+                output_size,
+                linear_units,
+                positionwise_conv_kernel_size,
+                dropout_rate,
+            )
+        elif positionwise_layer_type == "conv1d-linear":
+            positionwise_layer = Conv1dLinear
+            positionwise_layer_args = (
+                output_size,
+                linear_units,
+                positionwise_conv_kernel_size,
+                dropout_rate,
+            )
+        else:
+            raise NotImplementedError("Support only linear or conv1d.")
+
+        if selfattention_layer_type == "selfattn":
+            encoder_selfattn_layer = MultiHeadedAttention
+            encoder_selfattn_layer_args = (
+                attention_heads,
+                output_size,
+                attention_dropout_rate,
+            )
+        elif selfattention_layer_type == "legacy_rel_selfattn":
+            assert pos_enc_layer_type == "legacy_rel_pos"
+            encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention
+            encoder_selfattn_layer_args = (
+                attention_heads,
+                output_size,
+                attention_dropout_rate,
+            )
+            logging.warning(
+                "Using legacy_rel_selfattn and it will be deprecated in the future."
+            )
+        elif selfattention_layer_type == "rel_selfattn":
+            assert pos_enc_layer_type == "rel_pos"
+            encoder_selfattn_layer = RelPositionMultiHeadedAttention
+            encoder_selfattn_layer_args = (
+                attention_heads,
+                output_size,
+                attention_dropout_rate,
+                zero_triu,
+            )
+        else:
+            raise ValueError("unknown encoder_attn_layer: " + selfattention_layer_type)
+
+        convolution_layer = ConvolutionModule
+        convolution_layer_args = (output_size, cnn_module_kernel, activation)
+        encoder_selfattn_layer_raw = MultiHeadedAttention
+        encoder_selfattn_layer_args_raw = (
+            attention_heads,
+            output_size,
+            attention_dropout_rate,
+        )
+        self.encoders = repeat(
+            num_blocks,
+            lambda lnum: EncoderLayer(
+                output_size,
+                encoder_selfattn_layer_raw(*encoder_selfattn_layer_args_raw),
+                encoder_selfattn_layer(*encoder_selfattn_layer_args),
+                positionwise_layer(*positionwise_layer_args),
+                positionwise_layer(*positionwise_layer_args) if macaron_style else None,
+                convolution_layer(*convolution_layer_args) if use_cnn_module else None,
+                dropout_rate,
+                normalize_before,
+                concat_after,
+            ),
+        )
+        if self.normalize_before:
+            self.after_norm = LayerNorm(output_size)
+        self.conv1 = torch.nn.Conv2d(8, 16, [5,7], stride=[1,1], padding=(2,3))
+
+        self.conv2 = torch.nn.Conv2d(16, 32, [5,7], stride=[1,1], padding=(2,3))
+
+        self.conv3 = torch.nn.Conv2d(32, 16, [5,7], stride=[1,1], padding=(2,3))
+
+        self.conv4 = torch.nn.Conv2d(16, 1, [5,7], stride=[1,1], padding=(2,3))
+
+    def output_size(self) -> int:
+        return self._output_size
+
+    def forward(
+        self,
+        xs_pad: torch.Tensor,
+        ilens: torch.Tensor,
+        channel_size: torch.Tensor,
+        prev_states: torch.Tensor = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+        """Calculate forward propagation.
+
+        Args:
+            xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
+            ilens (torch.Tensor): Input length (#batch).
+            channel_size (int): Number of channels per utterance.
+            prev_states (torch.Tensor): Not to be used now.
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, L, output_size).
+            torch.Tensor: Output length (#batch).
+            torch.Tensor: Not to be used now.
+
+        """
+        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
+        if (
+            isinstance(self.embed, Conv2dSubsampling)
+            or isinstance(self.embed, Conv2dSubsampling6)
+            or isinstance(self.embed, Conv2dSubsampling8)
+        ):
+            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
+            if short_status:
+                raise TooShortUttError(
+                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
+                    + f"(it needs more than {limit_size} frames), return empty results",
+                    xs_pad.size(1),
+                    limit_size,
+                )
+            xs_pad, masks = self.embed(xs_pad, masks)
+        else:
+            xs_pad = self.embed(xs_pad)
+        xs_pad, masks, channel_size = self.encoders(xs_pad, masks, channel_size)
+        if isinstance(xs_pad, tuple):
+            xs_pad = xs_pad[0]
+
+        t_leng = xs_pad.size(1)
+        d_dim = xs_pad.size(2)
+        xs_pad = xs_pad.reshape(-1, channel_size, t_leng, d_dim)
+        # tile the channel dimension up to 8 views before the conv fusion
+        if channel_size < 8:
+            repeat_num = math.ceil(8 / channel_size)
+            xs_pad = xs_pad.repeat(1, repeat_num, 1, 1)[:, 0:8, :, :]
+        xs_pad = self.conv1(xs_pad)
+        xs_pad = self.conv2(xs_pad)
+        xs_pad = self.conv3(xs_pad)
+        xs_pad = self.conv4(xs_pad)
+        xs_pad = xs_pad.squeeze(1).reshape(-1, t_leng, d_dim)
+        mask_tmp = masks.size(1)
+        masks = masks.reshape(-1, channel_size, mask_tmp, t_leng)[:, 0, :, :]
+
+        if self.normalize_before:
+            xs_pad = self.after_norm(xs_pad)
+
+        olens = masks.squeeze(1).sum(1)
+        return xs_pad, olens, None
+
+    def forward_hidden(
+        self,
+        xs_pad: torch.Tensor,
+        ilens: torch.Tensor,
+        channel_size: torch.Tensor,
+        prev_states: torch.Tensor = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+        """Calculate forward propagation.
+
+        Args:
+            xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
+            ilens (torch.Tensor): Input length (#batch).
+            channel_size (int): Number of channels per utterance.
+            prev_states (torch.Tensor): Not to be used now.
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, L, output_size).
+            torch.Tensor: Output length (#batch).
+            torch.Tensor: Not to be used now.
+
+        """
+        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
+        if (
+            isinstance(self.embed, Conv2dSubsampling)
+            or isinstance(self.embed, Conv2dSubsampling6)
+            or isinstance(self.embed, Conv2dSubsampling8)
+        ):
+            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
+            if short_status:
+                raise TooShortUttError(
+                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
+                    + f"(it needs more than {limit_size} frames), return empty results",
+                    xs_pad.size(1),
+                    limit_size,
+                )
+            xs_pad, masks = self.embed(xs_pad, masks)
+        else:
+            xs_pad = self.embed(xs_pad)
+        num_layer = len(self.encoders)
+        for idx, encoder in enumerate(self.encoders):
+            xs_pad, masks, channel_size = encoder(xs_pad, masks, channel_size)
+            if idx == num_layer // 2 - 1:
+                hidden_feature = xs_pad
+        if isinstance(xs_pad, tuple):
+            xs_pad = xs_pad[0]
+            hidden_feature = hidden_feature[0]
+        if self.normalize_before:
+            xs_pad = self.after_norm(xs_pad)
+            self.hidden_feature = self.after_norm(hidden_feature)
+
+        olens = masks.squeeze(1).sum(1)
+        return xs_pad, olens, None
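The channel-fusion tail of `MFCCAEncoder.forward` tiles the per-channel encoder outputs up to 8 views, then collapses them to one channel with four stride-1 `Conv2d` layers (8 -> 16 -> 32 -> 16 -> 1) whose padding preserves the time and feature dimensions. A minimal sketch with toy sizes (sizes assumed):

```python
import math
import torch

C, T, D = 3, 10, 16                      # toy sizes (assumed)
xs_pad = torch.randn(1, C, T, D)

# Tile channels up to 8 views, as at the end of MFCCAEncoder.forward
if C < 8:
    repeat_num = math.ceil(8 / C)
    xs_pad = xs_pad.repeat(1, repeat_num, 1, 1)[:, 0:8]

# Four stride-1 Conv2d layers fuse 8 channels down to 1; kernel (5, 7)
# with padding (2, 3) keeps the (time, dim) shape unchanged
fusion = torch.nn.Sequential(
    torch.nn.Conv2d(8, 16, (5, 7), padding=(2, 3)),
    torch.nn.Conv2d(16, 32, (5, 7), padding=(2, 3)),
    torch.nn.Conv2d(32, 16, (5, 7), padding=(2, 3)),
    torch.nn.Conv2d(16, 1, (5, 7), padding=(2, 3)),
)
fused = fusion(xs_pad).squeeze(1)        # (1, T, D)
print(tuple(fused.shape))  # (1, 10, 16)
```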

+ 125 - 0
funasr/models/frontend/default.py

@@ -131,3 +131,128 @@ class DefaultFrontend(AbsFrontend):
         # input_stft: (..., F, 2) -> (..., F)
         input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])
         return input_stft, feats_lens
+
+
+
+
+class MultiChannelFrontend(AbsFrontend):
+    """Multi-channel frontend structure for ASR.
+
+    Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
+    """
+
+    def __init__(
+            self,
+            fs: Union[int, str] = 16000,
+            n_fft: int = 512,
+            win_length: int = None,
+            hop_length: int = 128,
+            window: Optional[str] = "hann",
+            center: bool = True,
+            normalized: bool = False,
+            onesided: bool = True,
+            n_mels: int = 80,
+            fmin: int = None,
+            fmax: int = None,
+            htk: bool = False,
+            frontend_conf: Optional[dict] = get_default_kwargs(Frontend),
+            apply_stft: bool = True,
+            frame_length: int = None,
+            frame_shift: int = None,
+            lfr_m: int = None,
+            lfr_n: int = None,
+    ):
+        assert check_argument_types()
+        super().__init__()
+        if isinstance(fs, str):
+            fs = humanfriendly.parse_size(fs)
+
+        # Deepcopy (In general, dict shouldn't be used as default arg)
+        frontend_conf = copy.deepcopy(frontend_conf)
+        self.hop_length = hop_length
+
+        if apply_stft:
+            self.stft = Stft(
+                n_fft=n_fft,
+                win_length=win_length,
+                hop_length=hop_length,
+                center=center,
+                window=window,
+                normalized=normalized,
+                onesided=onesided,
+            )
+        else:
+            self.stft = None
+        self.apply_stft = apply_stft
+
+        if frontend_conf is not None:
+            self.frontend = Frontend(idim=n_fft // 2 + 1, **frontend_conf)
+        else:
+            self.frontend = None
+
+        self.logmel = LogMel(
+            fs=fs,
+            n_fft=n_fft,
+            n_mels=n_mels,
+            fmin=fmin,
+            fmax=fmax,
+            htk=htk,
+        )
+        self.n_mels = n_mels
+        self.frontend_type = "multichannelfrontend"
+
+    def output_size(self) -> int:
+        return self.n_mels
+
+    def forward(
+            self, input: torch.Tensor, input_lengths: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, int]:
+        # 1. Domain-conversion: e.g. Stft: time -> time-freq
+        if self.stft is not None:
+            input_stft, feats_lens = self._compute_stft(input, input_lengths)
+        else:
+            if isinstance(input, ComplexTensor):
+                input_stft = input
+            else:
+                input_stft = ComplexTensor(input[..., 0], input[..., 1])
+            feats_lens = input_lengths
+        # 2. [Option] Speech enhancement
+        if self.frontend is not None:
+            assert isinstance(input_stft, ComplexTensor), type(input_stft)
+            # input_stft: (Batch, Length, [Channel], Freq)
+            input_stft, _, mask = self.frontend(input_stft, feats_lens)
+        # 3. STFT -> Power spectrum
+        # h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
+        input_power = input_stft.real ** 2 + input_stft.imag ** 2
+
+        # 4. Feature transform e.g. Stft -> Log-Mel-Fbank
+        # input_power: (Batch, [Channel,] Length, Freq)
+        #       -> input_feats: (Batch, Length, Dim)
+        input_feats, _ = self.logmel(input_power, feats_lens)
+        bt = input_feats.size(0)
+        if input_feats.dim() == 4:
+            channel_size = input_feats.size(2)
+            # (batch, T, channel, n_mels) -> (batch * channel, T, n_mels)
+            input_feats = (
+                input_feats.transpose(1, 2)
+                .reshape(bt * channel_size, -1, self.n_mels)
+                .contiguous()
+            )
+            # the flattening above is batch-major, so each utterance's
+            # length must be repeated once per channel
+            feats_lens = feats_lens.repeat_interleave(channel_size)
+        else:
+            channel_size = 1
+        return input_feats, feats_lens, channel_size
+
+    def _compute_stft(
+            self, input: torch.Tensor, input_lengths: torch.Tensor
+    ) -> Tuple[ComplexTensor, torch.Tensor]:
+        input_stft, feats_lens = self.stft(input, input_lengths)
+
+        assert input_stft.dim() >= 4, input_stft.shape
+        # "2" refers to the real/imag parts of Complex
+        assert input_stft.shape[-1] == 2, input_stft.shape
+
+        # Change torch.Tensor to ComplexTensor
+        # input_stft: (..., F, 2) -> (..., F)
+        input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])
+        return input_stft, feats_lens
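The channel-flattening step in `MultiChannelFrontend.forward` reshapes `(batch, T, channel, n_mels)` features into `(batch * channel, T, n_mels)`; since the flattening is batch-major, each utterance's length must be repeated once per channel. A numpy sketch of the shape bookkeeping (sizes and lengths here are made up for illustration):

```python
import numpy as np

batch, T, channels, n_mels = 2, 7, 4, 80
feats = np.zeros((batch, T, channels, n_mels))
lens = np.array([7, 5])

# (batch, T, channel, n_mels) -> (batch, channel, T, n_mels)
#                             -> (batch * channel, T, n_mels)
flat = feats.transpose(0, 2, 1, 3).reshape(batch * channels, T, n_mels)

# batch-major flattening: repeat each utterance's length per channel
flat_lens = np.repeat(lens, channels)

print(flat.shape)           # (8, 7, 80)
print(flat_lens.tolist())   # [7, 7, 7, 7, 5, 5, 5, 5]
```

Note that tiling the lengths (`[7, 5, 7, 5, ...]`) would not line up with this batch-major order; the element-wise repeat (`np.repeat` / `torch.repeat_interleave`) is what matches the reshape.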

+ 0 - 3
funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/.gitignore

@@ -1,3 +0,0 @@
-**/__pycache__
-*.onnx
-*.pyc

+ 12 - 2
funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/README.md

@@ -20,9 +20,19 @@ cd funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer
    pip install -r requirements.txt
    ```
 3. Export the model.
-    - Export your model([docs](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/export))
+   
+   - Export model from modelscope
+   ```shell
+   python -m funasr.export.export_model 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' "./export" true
+   ```
+   - Export model from local path
+   ```shell
+   python -m funasr.export.export_model '/mnt/workspace/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' "./export" true
+   ```
+   - For more details, refer to the [docs](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/export).
+
 
-4. Run the demo.
+5. Run the demo.
    - Model_dir: the model path, which contains `model.onnx`, `config.yaml`, `am.mvn`.
   - Input: wav format file; supported input types: `str, np.ndarray, List[str]`
    - Output: `List[str]`: recognition result.

+ 9 - 0
funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/demo.py

@@ -0,0 +1,9 @@
+from paraformer_onnx import Paraformer
+
+model_dir = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
+model = Paraformer(model_dir, batch_size=1)
+
+wav_path = ['/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav']
+
+result = model(wav_path)
+print(result)

+ 0 - 0
funasr/runtime/python/torchscripts/__init__.py


+ 0 - 0
funasr/runtime/python/torchscripts/paraformer/__init__.py


+ 9 - 1
funasr/tasks/abs_task.py

@@ -71,7 +71,7 @@ from funasr.utils.types import str2bool
 from funasr.utils.types import str2triple_str
 from funasr.utils.types import str_or_int
 from funasr.utils.types import str_or_none
-from funasr.utils.wav_utils import calc_shape, generate_data_list
+from funasr.utils.wav_utils import calc_shape, generate_data_list, filter_wav_text
 from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump
 
 try:
@@ -1153,6 +1153,14 @@ class AbsTask(ABC):
                 if args.batch_bins is not None:
                     args.batch_bins = args.batch_bins * args.ngpu
 
+        # filter samples if wav.scp and text are mismatch
+        if (args.train_shape_file is None and args.dataset_type == "small") or (args.train_data_file is None and args.dataset_type == "large"):
+            if not args.simple_ddp or distributed_option.dist_rank == 0:
+                filter_wav_text(args.data_dir, args.train_set)
+                filter_wav_text(args.data_dir, args.dev_set)
+            if args.simple_ddp:
+                dist.barrier()
+
         if args.train_shape_file is None and args.dataset_type == "small":
             if not args.simple_ddp or distributed_option.dist_rank == 0:
                 calc_shape(args.data_dir, args.train_set, args.frontend_conf, args.speech_length_min, args.speech_length_max)

+ 138 - 0
funasr/tasks/asr.py

@@ -40,6 +40,7 @@ from funasr.models.decoder.transformer_decoder import TransformerDecoder
 from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
 from funasr.models.e2e_asr import ESPnetASRModel
 from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer, ContextualParaformer
+from funasr.models.e2e_asr_mfcca import MFCCA
 from funasr.models.e2e_uni_asr import UniASR
 from funasr.models.encoder.abs_encoder import AbsEncoder
 from funasr.models.encoder.conformer_encoder import ConformerEncoder
@@ -47,8 +48,10 @@ from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
 from funasr.models.encoder.rnn_encoder import RNNEncoder
 from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
 from funasr.models.encoder.transformer_encoder import TransformerEncoder
+from funasr.models.encoder.mfcca_encoder import MFCCAEncoder
 from funasr.models.frontend.abs_frontend import AbsFrontend
 from funasr.models.frontend.default import DefaultFrontend
+from funasr.models.frontend.default import MultiChannelFrontend
 from funasr.models.frontend.fused import FusedFrontends
 from funasr.models.frontend.s3prl import S3prlFrontend
 from funasr.models.frontend.wav_frontend import WavFrontend
@@ -86,6 +89,7 @@ frontend_choices = ClassChoices(
         s3prl=S3prlFrontend,
         fused=FusedFrontends,
         wav_frontend=WavFrontend,
+        multichannelfrontend=MultiChannelFrontend,
     ),
     type_check=AbsFrontend,
     default="default",
@@ -119,6 +123,7 @@ model_choices = ClassChoices(
         paraformer_bert=ParaformerBert,
         bicif_paraformer=BiCifParaformer,
         contextual_paraformer=ContextualParaformer,
+        mfcca=MFCCA,
     ),
     type_check=AbsESPnetModel,
     default="asr",
@@ -142,6 +147,7 @@ encoder_choices = ClassChoices(
         sanm=SANMEncoder,
         sanm_chunk_opt=SANMEncoderChunkOpt,
         data2vec_encoder=Data2VecEncoder,
+        mfcca_enc=MFCCAEncoder,
     ),
     type_check=AbsEncoder,
     default="rnn",
@@ -1106,3 +1112,135 @@ class ASRTaskParaformer(ASRTask):
         var_dict_torch_update.update(var_dict_torch_update_local)
 
         return var_dict_torch_update
+
+
+
+class ASRTaskMFCCA(ASRTask):
+    # If you need more than one optimizers, change this value
+    num_optimizers: int = 1
+
+    # Add variable objects configurations
+    class_choices_list = [
+        # --frontend and --frontend_conf
+        frontend_choices,
+        # --specaug and --specaug_conf
+        specaug_choices,
+        # --normalize and --normalize_conf
+        normalize_choices,
+        # --model and --model_conf
+        model_choices,
+        # --preencoder and --preencoder_conf
+        preencoder_choices,
+        # --encoder and --encoder_conf
+        encoder_choices,
+        # --decoder and --decoder_conf
+        decoder_choices,
+    ]
+
+    # If you need to modify train() or eval() procedures, change Trainer class here
+    trainer = Trainer
+
+    @classmethod
+    def build_model(cls, args: argparse.Namespace):
+        assert check_argument_types()
+        if isinstance(args.token_list, str):
+            with open(args.token_list, encoding="utf-8") as f:
+                token_list = [line.rstrip() for line in f]
+
+            # Overwriting token_list to keep it as "portable".
+            args.token_list = list(token_list)
+        elif isinstance(args.token_list, (tuple, list)):
+            token_list = list(args.token_list)
+        else:
+            raise RuntimeError("token_list must be str or list")
+        vocab_size = len(token_list)
+        logging.info(f"Vocabulary size: {vocab_size}")
+
+        # 1. frontend
+        if args.input_size is None:
+            # Extract features in the model
+            frontend_class = frontend_choices.get_class(args.frontend)
+            if args.frontend == 'wav_frontend':
+                frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
+            else:
+                frontend = frontend_class(**args.frontend_conf)
+            input_size = frontend.output_size()
+        else:
+            # Give features from data-loader
+            args.frontend = None
+            args.frontend_conf = {}
+            frontend = None
+            input_size = args.input_size
+
+        # 2. Data augmentation for spectrogram
+        if args.specaug is not None:
+            specaug_class = specaug_choices.get_class(args.specaug)
+            specaug = specaug_class(**args.specaug_conf)
+        else:
+            specaug = None
+
+        # 3. Normalization layer
+        if args.normalize is not None:
+            normalize_class = normalize_choices.get_class(args.normalize)
+            normalize = normalize_class(stats_file=args.cmvn_file, **args.normalize_conf)
+        else:
+            normalize = None
+
+        # 4. Pre-encoder input block
+        # NOTE(kan-bayashi): Use getattr to keep the compatibility
+        if getattr(args, "preencoder", None) is not None:
+            preencoder_class = preencoder_choices.get_class(args.preencoder)
+            preencoder = preencoder_class(**args.preencoder_conf)
+            input_size = preencoder.output_size()
+        else:
+            preencoder = None
+
+        # 5. Encoder
+        encoder_class = encoder_choices.get_class(args.encoder)
+        encoder = encoder_class(input_size=input_size, **args.encoder_conf)
+
+        # 6. Decoder
+        decoder_class = decoder_choices.get_class(args.decoder)
+        decoder = decoder_class(
+            vocab_size=vocab_size,
+            encoder_output_size=encoder.output_size(),
+            **args.decoder_conf,
+        )
+
+        # 7. CTC
+        ctc = CTC(
+            odim=vocab_size, encoder_output_size=encoder.output_size(), **args.ctc_conf
+        )
+
+        # 8. Build model
+        try:
+            model_class = model_choices.get_class(args.model)
+        except AttributeError:
+            model_class = model_choices.get_class("asr")
+
+        rnnt_decoder = None
+
+        model = model_class(
+            vocab_size=vocab_size,
+            frontend=frontend,
+            specaug=specaug,
+            normalize=normalize,
+            preencoder=preencoder,
+            encoder=encoder,
+            decoder=decoder,
+            ctc=ctc,
+            rnnt_decoder=rnnt_decoder,
+            token_list=token_list,
+            **args.model_conf,
+        )
+
+        # 9. Initialize
+        if args.init is not None:
+            initialize(model, args.init)
+
+        assert check_return_type(model)
+        return model
+
+

+ 32 - 0
funasr/utils/wav_utils.py

@@ -287,3 +287,35 @@ def generate_data_list(data_dir, dataset, nj=100):
             wav_path = os.path.join(split_dir, str(i + 1), "wav.scp")
             text_path = os.path.join(split_dir, str(i + 1), "text")
             f_data.write(wav_path + " " + text_path + "\n")
+
+def filter_wav_text(data_dir, dataset):
+    wav_file = os.path.join(data_dir,dataset,"wav.scp")
+    text_file = os.path.join(data_dir, dataset, "text")
+    with open(wav_file) as f_wav, open(text_file) as f_text:
+        wav_lines = f_wav.readlines()
+        text_lines = f_text.readlines()
+    os.rename(wav_file, "{}.bak".format(wav_file))
+    os.rename(text_file, "{}.bak".format(text_file))
+    wav_dict = {}
+    for line in wav_lines:
+        parts = line.strip().split()
+        if len(parts) < 2:
+            continue
+        sample_name, wav_path = parts
+        wav_dict[sample_name] = wav_path
+    text_dict = {}
+    for line in text_lines:
+        parts = line.strip().split(" ", 1)
+        if len(parts) < 2:
+            continue
+        sample_name, txt = parts
+        text_dict[sample_name] = txt
+    filter_count = 0
+    with open(wav_file, "w") as f_wav, open(text_file, "w") as f_text:
+        for sample_name, wav_path in wav_dict.items():
+            if sample_name in text_dict:
+                f_wav.write(sample_name + " " + wav_path + "\n")
+                f_text.write(sample_name + " " + text_dict[sample_name] + "\n")
+            else:
+                filter_count += 1
+    print("{}/{} samples in {} are filtered because of the mismatch between wav.scp and text".format(filter_count, len(wav_lines), dataset))
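The core of `filter_wav_text` is an id intersection between `wav.scp` and `text`; a minimal sketch of that matching logic with toy data (utterance ids and paths are made up):

```python
# toy wav.scp / text contents with one mismatched utterance id
scp = {"utt1": "/data/utt1.wav", "utt2": "/data/utt2.wav", "utt3": "/data/utt3.wav"}
txt = {"utt1": "hello", "utt3": "world"}

# keep only ids present in both files; count the rest as filtered
kept = {u: (path, txt[u]) for u, path in scp.items() if u in txt}
filtered = len(scp) - len(kept)

print(sorted(kept))  # ['utt1', 'utt3']
print(filtered)      # 1
```

The real function additionally backs up the original files as `wav.scp.bak` / `text.bak` before rewriting the filtered versions in place.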