Procházet zdrojové kódy

Merge branch 'main' of https://github.com/alibaba-damo-academy/FunASR into main

雾聪 před 2 roky
rodič
revize
69dcdbcfc0

+ 2 - 1
.gitignore

@@ -18,4 +18,5 @@ dist
 build
 funasr.egg-info
 docs/_build
-modelscope
+modelscope
+samples

+ 7 - 4
README.md

@@ -34,9 +34,9 @@ For the release notes, please ref to [news](https://github.com/alibaba-damo-acad
 
 Install from pip
 ```shell
-pip install -U funasr
+pip3 install -U funasr
 # For the users in China, you could install with the command:
-# pip install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple
+# pip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple
 ```
 
 Or install from source code
@@ -96,14 +96,17 @@ More examples could be found in [docs](https://alibaba-damo-academy.github.io/Fu
 ### runtime
 
 An example with websocket:
+
 For the server:
 ```shell
+cd funasr/runtime/python/websocket
 python wss_srv_asr.py --port 10095
 ```
+
 For the client:
 ```shell
-python wss_client_asr.py --host "0.0.0.0" --port 10095 --mode 2pass --chunk_size "5,10,5"
-#python wss_client_asr.py --host "0.0.0.0" --port 10095 --mode 2pass --chunk_size "8,8,4" --audio_in "./data/wav.scp" --output_dir "./results"
+python wss_client_asr.py --host "127.0.0.1" --port 10095 --mode 2pass --chunk_size "5,10,5"
+#python wss_client_asr.py --host "127.0.0.1" --port 10095 --mode 2pass --chunk_size "8,8,4" --audio_in "./data/wav.scp" --output_dir "./results"
 ```
 More examples could be found in [docs](https://alibaba-damo-academy.github.io/FunASR/en/runtime/websocket_python.html#id2)
 ## Contact

+ 127 - 127
docs/benchmark/benchmark_pipeline_cer.md

@@ -45,156 +45,156 @@ bash infer.sh
 ### Chinese Dataset
 
 
-<table>
+<table border="1">
     <tr align="center">
-        <td>Model</td>
-        <td>Offline/Online</td>
-        <td colspan="2">Aishell1</td>
-        <td colspan="4">Aishell2</td>
-        <td colspan="3">WenetSpeech</td>
+        <td style="border: 1px solid">Model</td>
+        <td style="border: 1px solid">Offline/Online</td>
+        <td colspan="2" style="border: 1px solid">Aishell1</td>
+        <td colspan="4" style="border: 1px solid">Aishell2</td>
+        <td colspan="3" style="border: 1px solid">WenetSpeech</td>
     </tr>
     <tr align="center">
-        <td></td>
-        <td></td>
-        <td>dev</td> 
-        <td>test</td>
-        <td>dev_ios</td>
-        <td>test_ios</td>
-        <td>test_android</td>
-        <td>test_mic</td>
-        <td>dev</td>
-        <td>test_meeting</td>
-        <td>test_net</td>
+        <td style="border: 1px solid"></td>
+        <td style="border: 1px solid"></td>
+        <td style="border: 1px solid">dev</td> 
+        <td style="border: 1px solid">test</td>
+        <td style="border: 1px solid">dev_ios</td>
+        <td style="border: 1px solid">test_ios</td>
+        <td style="border: 1px solid">test_android</td>
+        <td style="border: 1px solid">test_mic</td>
+        <td style="border: 1px solid">dev</td>
+        <td style="border: 1px solid">test_meeting</td>
+        <td style="border: 1px solid">test_net</td>
     </tr>
     <tr align="center">
-        <td> <a href="https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary">Paraformer-large</a> </td>
-        <td>Offline</td>
-        <td>1.76</td>
-        <td>1.94</td>
-        <td>2.79</td>
-        <td>2.84</td>
-        <td>3.08</td>
-        <td>3.03</td>
-        <td>3.43</td>
-        <td>7.01</td>
-        <td>6.66</td>
+        <td style="border: 1px solid"> <a href="https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary">Paraformer-large</a> </td>
+        <td style="border: 1px solid">Offline</td>
+        <td style="border: 1px solid">1.76</td>
+        <td style="border: 1px solid">1.94</td>
+        <td style="border: 1px solid">2.79</td>
+        <td style="border: 1px solid">2.84</td>
+        <td style="border: 1px solid">3.08</td>
+        <td style="border: 1px solid">3.03</td>
+        <td style="border: 1px solid">3.43</td>
+        <td style="border: 1px solid">7.01</td>
+        <td style="border: 1px solid">6.66</td>
     </tr>
     <tr align="center">
-        <td> <a href="https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary">Paraformer-large-long</a> </td> 
-        <td>Offline</td>      
-        <td>1.80</td>
-        <td>2.10</td>
-        <td>2.78</td>
-        <td>2.87</td>
-        <td>3.12</td>
-        <td>3.11</td>
-        <td>3.44</td>
-        <td>13.28</td>
-        <td>7.08</td>
+        <td style="border: 1px solid"> <a href="https://www.modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary">Paraformer-large-long</a> </td> 
+        <td style="border: 1px solid">Offline</td>      
+        <td style="border: 1px solid">1.80</td>
+        <td style="border: 1px solid">2.10</td>
+        <td style="border: 1px solid">2.78</td>
+        <td style="border: 1px solid">2.87</td>
+        <td style="border: 1px solid">3.12</td>
+        <td style="border: 1px solid">3.11</td>
+        <td style="border: 1px solid">3.44</td>
+        <td style="border: 1px solid">13.28</td>
+        <td style="border: 1px solid">7.08</td>
     </tr>
     <tr align="center">
-        <td> <a href="https://www.modelscope.cn/models/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/summary">Paraformer-large-contextual</a> </td>
-        <td>Offline</td>
-        <td>1.76</td>
-        <td>2.02</td>
-        <td>2.73</td>
-        <td>2.85</td>
-        <td>2.98</td>
-        <td>2.95</td>
-        <td>3.42</td>
-        <td>7.16</td>
-        <td>6.72</td>
+        <td style="border: 1px solid"> <a href="https://www.modelscope.cn/models/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/summary">Paraformer-large-contextual</a> </td>
+        <td style="border: 1px solid">Offline</td>
+        <td style="border: 1px solid">1.76</td>
+        <td style="border: 1px solid">2.02</td>
+        <td style="border: 1px solid">2.73</td>
+        <td style="border: 1px solid">2.85</td>
+        <td style="border: 1px solid">2.98</td>
+        <td style="border: 1px solid">2.95</td>
+        <td style="border: 1px solid">3.42</td>
+        <td style="border: 1px solid">7.16</td>
+        <td style="border: 1px solid">6.72</td>
     </tr>
     <tr align="center">
-        <td> <a href="https://modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1/summary">Paraformer</a> </td> 
-        <td>Offline</td>
-        <td>3.24</td>
-        <td>3.69</td>
-        <td>4.58</td>
-        <td>4.63</td>
-        <td>4.83</td>
-        <td>4.71</td>
-        <td>4.19</td>
-        <td>8.32</td>
-        <td>9.19</td>
+        <td style="border: 1px solid"> <a href="https://modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1/summary">Paraformer</a> </td> 
+        <td style="border: 1px solid">Offline</td>
+        <td style="border: 1px solid">3.24</td>
+        <td style="border: 1px solid">3.69</td>
+        <td style="border: 1px solid">4.58</td>
+        <td style="border: 1px solid">4.63</td>
+        <td style="border: 1px solid">4.83</td>
+        <td style="border: 1px solid">4.71</td>
+        <td style="border: 1px solid">4.19</td>
+        <td style="border: 1px solid">8.32</td>
+        <td style="border: 1px solid">9.19</td>
     </tr>
    <tr align="center">
-        <td> <a href="https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-online/summary">UniASR</a> </td> 
-        <td>Online</td>
-        <td>3.34</td>
-        <td>3.99</td>
-        <td>4.62</td>
-        <td>4.52</td>
-        <td>4.77</td>
-        <td>4.73</td>
-        <td>4.51</td>
-        <td>10.63</td>
-        <td>9.70</td>
+        <td style="border: 1px solid"> <a href="https://modelscope.cn/models/damo/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-online/summary">UniASR</a> </td> 
+        <td style="border: 1px solid">Online</td>
+        <td style="border: 1px solid">3.34</td>
+        <td style="border: 1px solid">3.99</td>
+        <td style="border: 1px solid">4.62</td>
+        <td style="border: 1px solid">4.52</td>
+        <td style="border: 1px solid">4.77</td>
+        <td style="border: 1px solid">4.73</td>
+        <td style="border: 1px solid">4.51</td>
+        <td style="border: 1px solid">10.63</td>
+        <td style="border: 1px solid">9.70</td>
     </tr>
    <tr align="center">
-        <td> <a href="https://modelscope.cn/models/damo/speech_UniASR-large_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline/summary">UniASR-large</a> </td> 
-        <td>Offline</td>      
-        <td>2.93</td>
-        <td>3.48</td>
-        <td>3.95</td>
-        <td>3.87</td>
-        <td>4.11</td>
-        <td>4.11</td>
-        <td>4.16</td>
-        <td>10.09</td>
-        <td>8.69</td>
+        <td style="border: 1px solid"> <a href="https://modelscope.cn/models/damo/speech_UniASR-large_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline/summary">UniASR-large</a> </td> 
+        <td style="border: 1px solid">Offline</td>      
+        <td style="border: 1px solid">2.93</td>
+        <td style="border: 1px solid">3.48</td>
+        <td style="border: 1px solid">3.95</td>
+        <td style="border: 1px solid">3.87</td>
+        <td style="border: 1px solid">4.11</td>
+        <td style="border: 1px solid">4.11</td>
+        <td style="border: 1px solid">4.16</td>
+        <td style="border: 1px solid">10.09</td>
+        <td style="border: 1px solid">8.69</td>
     </tr>
     <tr align="center">
-        <td> <a href="https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-aishell1-pytorch/summary">Paraformer-aishell</a> </td>
-        <td>Offline</td>
-        <td>4.88</td>
-        <td>5.43</td>
-        <td>-</td>
-        <td>-</td>
-        <td>-</td>
-        <td>-</td>
-        <td>-</td>
-        <td>-</td>
-        <td>-</td>
+        <td style="border: 1px solid"> <a href="https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-aishell1-pytorch/summary">Paraformer-aishell</a> </td>
+        <td style="border: 1px solid">Offline</td>
+        <td style="border: 1px solid">4.88</td>
+        <td style="border: 1px solid">5.43</td>
+        <td style="border: 1px solid">-</td>
+        <td style="border: 1px solid">-</td>
+        <td style="border: 1px solid">-</td>
+        <td style="border: 1px solid">-</td>
+        <td style="border: 1px solid">-</td>
+        <td style="border: 1px solid">-</td>
+        <td style="border: 1px solid">-</td>
     </tr>
    <tr align="center">
-        <td> <a href="https://modelscope.cn/models/damo/speech_paraformerbert_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/summary">ParaformerBert-aishell</a> </td>
-        <td>Offline</td>
-        <td>6.14</td>
-        <td>7.01</td>
-        <td>-</td>
-        <td>-</td>
-        <td>-</td>
-        <td>-</td>
-        <td>-</td>
-        <td>-</td>
-        <td>-</td>
+        <td style="border: 1px solid"> <a href="https://modelscope.cn/models/damo/speech_paraformerbert_asr_nat-zh-cn-16k-aishell1-vocab4234-pytorch/summary">ParaformerBert-aishell</a> </td>
+        <td style="border: 1px solid">Offline</td>
+        <td style="border: 1px solid">6.14</td>
+        <td style="border: 1px solid">7.01</td>
+        <td style="border: 1px solid">-</td>
+        <td style="border: 1px solid">-</td>
+        <td style="border: 1px solid">-</td>
+        <td style="border: 1px solid">-</td>
+        <td style="border: 1px solid">-</td>
+        <td style="border: 1px solid">-</td>
+        <td style="border: 1px solid">-</td>
     </tr>
    <tr align="center">
-        <td> <a href="https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/summary">Paraformer-aishell2</a> </td> 
-        <td>Offline</td>
-        <td>-</td>
-        <td>-</td>
-        <td>5.82</td>
-        <td>6.30</td>
-        <td>6.60</td>
-        <td>5.83</td>
-        <td>-</td>
-        <td>-</td>
-        <td>-</td>
+        <td style="border: 1px solid"> <a href="https://www.modelscope.cn/models/damo/speech_paraformer_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/summary">Paraformer-aishell2</a> </td> 
+        <td style="border: 1px solid">Offline</td>
+        <td style="border: 1px solid">-</td>
+        <td style="border: 1px solid">-</td>
+        <td style="border: 1px solid">5.82</td>
+        <td style="border: 1px solid">6.30</td>
+        <td style="border: 1px solid">6.60</td>
+        <td style="border: 1px solid">5.83</td>
+        <td style="border: 1px solid">-</td>
+        <td style="border: 1px solid">-</td>
+        <td style="border: 1px solid">-</td>
     </tr>
    <tr align="center">
-        <td> <a href="https://www.modelscope.cn/models/damo/speech_paraformerbert_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/summary">ParaformerBert-aishell2</a> </td> 
-        <td>Offline</td>
-        <td>-</td>
-        <td>-</td>
-        <td>4.95</td>
-        <td>5.45</td>
-        <td>5.59</td>
-        <td>5.83</td>
-        <td>-</td>
-        <td>-</td>
-        <td>-</td>
+        <td style="border: 1px solid"> <a href="https://www.modelscope.cn/models/damo/speech_paraformerbert_asr_nat-zh-cn-16k-aishell2-vocab5212-pytorch/summary">ParaformerBert-aishell2</a> </td> 
+        <td style="border: 1px solid">Offline</td>
+        <td style="border: 1px solid">-</td>
+        <td style="border: 1px solid">-</td>
+        <td style="border: 1px solid">4.95</td>
+        <td style="border: 1px solid">5.45</td>
+        <td style="border: 1px solid">5.59</td>
+        <td style="border: 1px solid">5.83</td>
+        <td style="border: 1px solid">-</td>
+        <td style="border: 1px solid">-</td>
+        <td style="border: 1px solid">-</td>
     </tr>
 </table>
 

+ 7 - 7
docs/installation/installation.md

@@ -32,7 +32,7 @@ Ref to [docs](https://docs.conda.io/en/latest/miniconda.html#windows-installers)
 ### Install Pytorch (version >= 1.11.0):
 
 ```sh
-pip install torch torchaudio
+pip3 install torch torchaudio
 ```
 If there exists CUDAs in your environments, you should install the pytorch with the version matching the CUDA. The matching list could be found in [docs](https://pytorch.org/get-started/previous-versions/).
 ### Install funasr
@@ -40,27 +40,27 @@ If there exists CUDAs in your environments, you should install the pytorch with
 #### Install from pip
 
 ```shell
-pip install -U funasr
+pip3 install -U funasr
 # For the users in China, you could install with the command:
-# pip install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple
+# pip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple
 ```
 
 #### Or install from source code
 
 ``` sh
 git clone https://github.com/alibaba/FunASR.git && cd FunASR
-pip install -e ./
+pip3 install -e ./
 # For the users in China, you could install with the command:
-# pip install -e ./ -i https://mirror.sjtu.edu.cn/pypi/web/simple
+# pip3 install -e ./ -i https://mirror.sjtu.edu.cn/pypi/web/simple
 ```
 
 ### Install modelscope (Optional)
 If you want to use the pretrained models in ModelScope, you should install the modelscope:
 
 ```shell
-pip install -U modelscope
+pip3 install -U modelscope
 # For the users in China, you could install with the command:
-# pip install -U modelscope -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html -i https://mirror.sjtu.edu.cn/pypi/web/simple
+# pip3 install -U modelscope -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html -i https://mirror.sjtu.edu.cn/pypi/web/simple
 ```
 
 ### FQA

+ 2 - 2
egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml

@@ -6,7 +6,7 @@ encoder_conf:
       unified_model_training: true
       default_chunk_size: 16
       jitter_range: 4
-      left_chunk_size: 0
+      left_chunk_size: 1
       embed_vgg_like: false
       subsampling_factor: 4
       linear_units: 2048
@@ -51,7 +51,7 @@ use_amp: true
 # optimization related
 accum_grad: 1
 grad_clip: 5
-max_epoch: 200
+max_epoch: 120
 val_scheduler_criterion:
     - valid
     - loss

+ 8 - 1
funasr/bin/asr_inference_launch.py

@@ -19,6 +19,7 @@ from typing import Union
 import numpy as np
 import torch
 import torchaudio
+import soundfile
 import yaml
 from typeguard import check_argument_types
 
@@ -863,7 +864,13 @@ def inference_paraformer_online(
             raw_inputs = _load_bytes(data_path_and_name_and_type[0])
             raw_inputs = torch.tensor(raw_inputs)
         if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound":
-            raw_inputs = torchaudio.load(data_path_and_name_and_type[0])[0][0]
+            try:
+                raw_inputs = torchaudio.load(data_path_and_name_and_type[0])[0][0]
+            except:
+                raw_inputs = soundfile.read(data_path_and_name_and_type[0], dtype='float32')[0]
+                if raw_inputs.ndim == 2:
+                    raw_inputs = raw_inputs[:, 0]
+                raw_inputs = torch.tensor(raw_inputs)
         if data_path_and_name_and_type is None and raw_inputs is not None:
             if isinstance(raw_inputs, np.ndarray):
                 raw_inputs = torch.tensor(raw_inputs)

+ 11 - 1
funasr/datasets/iterable_dataset.py

@@ -14,6 +14,7 @@ import kaldiio
 import numpy as np
 import torch
 import torchaudio
+import soundfile
 from torch.utils.data.dataset import IterableDataset
 from typeguard import check_argument_types
 import os.path
@@ -66,8 +67,17 @@ def load_pcm(input):
         bytes = f.read()
     return load_bytes(bytes)
 
+def load_wav(input):
+    try:
+        return torchaudio.load(input)[0].numpy()
+    except:
+        waveform, _ = soundfile.read(input, dtype='float32')
+        if waveform.ndim == 2:
+            waveform = waveform[:, 0]
+        return np.expand_dims(waveform, axis=0)
+
 DATA_TYPES = {
-    "sound": lambda x: torchaudio.load(x)[0].numpy(),
+    "sound": load_wav,
     "pcm": load_pcm,
     "kaldi_ark": load_kaldi,
     "bytes": load_bytes,

+ 10 - 1
funasr/datasets/large_datasets/dataset.py

@@ -6,6 +6,8 @@ from functools import partial
 import torch
 import torch.distributed as dist
 import torchaudio
+import numpy as np
+import soundfile
 from kaldiio import ReadHelper
 from torch.utils.data import IterableDataset
 
@@ -123,7 +125,14 @@ class AudioDataset(IterableDataset):
                             sample_dict["key"] = key
                     elif data_type == "sound":
                         key, path = item.strip().split()
-                        waveform, sampling_rate = torchaudio.load(path)
+                        try:
+                            waveform, sampling_rate = torchaudio.load(path)
+                        except:
+                            waveform, sampling_rate = soundfile.read(path, dtype='float32')
+                            if waveform.ndim == 2:
+                                waveform = waveform[:, 0]
+                            waveform = np.expand_dims(waveform, axis=0)
+                            waveform = torch.tensor(waveform)
                         if self.frontend_conf is not None:
                             if sampling_rate != self.frontend_conf["fs"]:
                                 waveform = torchaudio.transforms.Resample(orig_freq=sampling_rate,

+ 0 - 0
funasr/datasets/small_datasets/__init__.py


+ 69 - 1
funasr/fileio/sound_scp.py

@@ -1,6 +1,6 @@
 import collections.abc
 from pathlib import Path
-from typing import Union
+from typing import List, Tuple, Union
 
 import random
 import numpy as np
@@ -13,6 +13,74 @@ import torchaudio
 
 from funasr.fileio.read_text import read_2column_text
 
+def soundfile_read(
+    wavs: Union[str, List[str]],
+    dtype=None,
+    always_2d: bool = False,
+    concat_axis: int = 1,
+    start: int = 0,
+    end: int = None,
+    return_subtype: bool = False,
+) -> Tuple[np.array, int]:
+    if isinstance(wavs, str):
+        wavs = [wavs]
+
+    arrays = []
+    subtypes = []
+    prev_rate = None
+    prev_wav = None
+    for wav in wavs:
+        with soundfile.SoundFile(wav) as f:
+            f.seek(start)
+            if end is not None:
+                frames = end - start
+            else:
+                frames = -1
+            if dtype == "float16":
+                array = f.read(
+                    frames,
+                    dtype="float32",
+                    always_2d=always_2d,
+                ).astype(dtype)
+            else:
+                array = f.read(frames, dtype=dtype, always_2d=always_2d)
+            rate = f.samplerate
+            subtype = f.subtype
+            subtypes.append(subtype)
+
+        if len(wavs) > 1 and array.ndim == 1 and concat_axis == 1:
+            # array: (Time, Channel)
+            array = array[:, None]
+
+        if prev_wav is not None:
+            if prev_rate != rate:
+                raise RuntimeError(
+                    f"'{prev_wav}' and '{wav}' have mismatched sampling rate: "
+                    f"{prev_rate} != {rate}"
+                )
+
+            dim1 = arrays[0].shape[1 - concat_axis]
+            dim2 = array.shape[1 - concat_axis]
+            if dim1 != dim2:
+                raise RuntimeError(
+                    "Shapes must match with "
+                    f"{1 - concat_axis} axis, but gut {dim1} and {dim2}"
+                )
+
+        prev_rate = rate
+        prev_wav = wav
+        arrays.append(array)
+
+    if len(arrays) == 1:
+        array = arrays[0]
+    else:
+        array = np.concatenate(arrays, axis=concat_axis)
+
+    if return_subtype:
+        return array, rate, subtypes
+    else:
+        return array, rate
+
 
 class SoundScpReader(collections.abc.Mapping):
     """Reader class for 'wav.scp'.

+ 50 - 5
funasr/models/encoder/conformer_encoder.py

@@ -1081,7 +1081,10 @@ class ConformerChunkEncoder(AbsEncoder):
         mask = make_source_mask(x_len).to(x.device)
 
         if self.unified_model_training:
-            chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item()
+            if self.training:
+                chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item()
+            else:
+                chunk_size = self.default_chunk_size
             x, mask = self.embed(x, mask, chunk_size)
             pos_enc = self.pos_enc(x)
             chunk_mask = make_chunk_mask(
@@ -1113,12 +1116,15 @@ class ConformerChunkEncoder(AbsEncoder):
 
         elif self.dynamic_chunk_training:
             max_len = x.size(1)
-            chunk_size = torch.randint(1, max_len, (1,)).item()
+            if self.training:
+                chunk_size = torch.randint(1, max_len, (1,)).item()
 
-            if chunk_size > (max_len * self.short_chunk_threshold):
-                chunk_size = max_len
+                if chunk_size > (max_len * self.short_chunk_threshold):
+                    chunk_size = max_len
+                else:
+                    chunk_size = (chunk_size % self.short_chunk_size) + 1
             else:
-                chunk_size = (chunk_size % self.short_chunk_size) + 1
+                chunk_size = self.default_chunk_size
 
             x, mask = self.embed(x, mask, chunk_size)
             pos_enc = self.pos_enc(x)
@@ -1147,6 +1153,45 @@ class ConformerChunkEncoder(AbsEncoder):
 
         return x, olens, None
 
+    def full_utt_forward(
+        self,
+        x: torch.Tensor,
+        x_len: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Encode input sequences.
+        Args:
+            x: Encoder input features. (B, T_in, F)
+            x_len: Encoder input features lengths. (B,)
+        Returns:
+           x: Encoder outputs. (B, T_out, D_enc)
+           x_len: Encoder outputs lenghts. (B,)
+        """
+        short_status, limit_size = check_short_utt(
+            self.embed.subsampling_factor, x.size(1)
+        )
+
+        if short_status:
+            raise TooShortUttError(
+                f"has {x.size(1)} frames and is too short for subsampling "
+                + f"(it needs more than {limit_size} frames), return empty results",
+                x.size(1),
+                limit_size,
+            )
+
+        mask = make_source_mask(x_len).to(x.device)
+        x, mask = self.embed(x, mask, None)
+        pos_enc = self.pos_enc(x)
+        x_utt = self.encoders(
+            x,
+            pos_enc,
+            mask,
+            chunk_mask=None,
+        )
+
+        if self.time_reduction_factor > 1:
+            x_utt = x_utt[:,::self.time_reduction_factor,:]
+        return x_utt
+
     def simu_chunk_forward(
         self,
         x: torch.Tensor,

+ 0 - 0
funasr/modules/eend_ola/utils/__init__.py


+ 11 - 10
funasr/modules/subsampling.py

@@ -427,6 +427,7 @@ class StreamingConvInput(torch.nn.Module):
         conv_size: Union[int, Tuple],
         subsampling_factor: int = 4,
         vgg_like: bool = True,
+        conv_kernel_size: int = 3,
         output_size: Optional[int] = None,
     ) -> None:
         """Construct a ConvInput object."""
@@ -436,14 +437,14 @@ class StreamingConvInput(torch.nn.Module):
                 conv_size1, conv_size2 = conv_size
 
                 self.conv = torch.nn.Sequential(
-                    torch.nn.Conv2d(1, conv_size1, 3, stride=1, padding=1),
+                    torch.nn.Conv2d(1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                     torch.nn.ReLU(),
-                    torch.nn.Conv2d(conv_size1, conv_size1, 3, stride=1, padding=1),
+                    torch.nn.Conv2d(conv_size1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                     torch.nn.ReLU(),
                     torch.nn.MaxPool2d((1, 2)),
-                    torch.nn.Conv2d(conv_size1, conv_size2, 3, stride=1, padding=1),
+                    torch.nn.Conv2d(conv_size1, conv_size2, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                     torch.nn.ReLU(),
-                    torch.nn.Conv2d(conv_size2, conv_size2, 3, stride=1, padding=1),
+                    torch.nn.Conv2d(conv_size2, conv_size2, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                     torch.nn.ReLU(),
                     torch.nn.MaxPool2d((1, 2)),
                 )
@@ -462,14 +463,14 @@ class StreamingConvInput(torch.nn.Module):
                 kernel_1 = int(subsampling_factor / 2)
 
                 self.conv = torch.nn.Sequential(
-                    torch.nn.Conv2d(1, conv_size1, 3, stride=1, padding=1),
+                    torch.nn.Conv2d(1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                     torch.nn.ReLU(),
-                    torch.nn.Conv2d(conv_size1, conv_size1, 3, stride=1, padding=1),
+                    torch.nn.Conv2d(conv_size1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                     torch.nn.ReLU(),
                     torch.nn.MaxPool2d((kernel_1, 2)),
-                    torch.nn.Conv2d(conv_size1, conv_size2, 3, stride=1, padding=1),
+                    torch.nn.Conv2d(conv_size1, conv_size2, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                     torch.nn.ReLU(),
-                    torch.nn.Conv2d(conv_size2, conv_size2, 3, stride=1, padding=1),
+                    torch.nn.Conv2d(conv_size2, conv_size2, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                     torch.nn.ReLU(),
                     torch.nn.MaxPool2d((2, 2)),
                 )
@@ -487,14 +488,14 @@ class StreamingConvInput(torch.nn.Module):
                 self.conv = torch.nn.Sequential(
                     torch.nn.Conv2d(1, conv_size, 3, [1,2], [1,0]),
                     torch.nn.ReLU(),
-                    torch.nn.Conv2d(conv_size, conv_size, 3, [1,2], [1,0]),
+                    torch.nn.Conv2d(conv_size, conv_size, conv_kernel_size, [1,2], [1,0]),
                     torch.nn.ReLU(),
                 )
 
                 output_proj = conv_size * (((input_size - 1) // 2 - 1) // 2)
 
                 self.subsampling_factor = subsampling_factor
-                self.kernel_2 = 3
+                self.kernel_2 = conv_kernel_size
                 self.stride_2 = 1
 
                 self.create_new_mask = self.create_new_conv2d_mask

+ 5 - 1
funasr/utils/asr_utils.py

@@ -5,6 +5,7 @@ import struct
 from typing import Any, Dict, List, Union
 
 import torchaudio
+import soundfile
 import numpy as np
 import pkg_resources
 from modelscope.utils.logger import get_logger
@@ -135,7 +136,10 @@ def get_sr_from_wav(fname: str):
                 if support_audio_type == "pcm":
                     fs = None
                 else:
-                    audio, fs = torchaudio.load(fname)
+                    try:
+                        audio, fs = torchaudio.load(fname)
+                    except:
+                        audio, fs = soundfile.read(fname)
                 break
         if audio_type.rfind(".scp") >= 0:
             with open(fname, encoding="utf-8") as f:

+ 6 - 1
funasr/utils/prepare_data.py

@@ -7,6 +7,7 @@ import kaldiio
 import numpy as np
 import torch.distributed as dist
 import torchaudio
+import soundfile
 
 
 def filter_wav_text(data_dir, dataset):
@@ -42,7 +43,11 @@ def filter_wav_text(data_dir, dataset):
 
 
 def wav2num_frame(wav_path, frontend_conf):
-    waveform, sampling_rate = torchaudio.load(wav_path)
+    try:
+        waveform, sampling_rate = torchaudio.load(wav_path)
+    except:
+        waveform, sampling_rate = soundfile.read(wav_path)
+        waveform = np.expand_dims(waveform, axis=0)
     n_frames = (waveform.shape[1] * 1000.0) / (sampling_rate * frontend_conf["frame_shift"] * frontend_conf["lfr_n"])
     feature_dim = frontend_conf["n_mels"] * frontend_conf["lfr_m"]
     return n_frames, feature_dim

+ 38 - 0
funasr/utils/runtime_sdk_download_tool.py

@@ -0,0 +1,38 @@
+from pathlib import Path
+import os
+import argparse
+from funasr.utils.types import str2bool
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--model-name', type=str, required=True)
+parser.add_argument('--export-dir', type=str, required=True)
+parser.add_argument('--type', type=str, default='onnx', help='["onnx", "torch"]')
+parser.add_argument('--device', type=str, default='cpu', help='["cpu", "cuda"]')
+parser.add_argument('--quantize', type=str2bool, default=False, help='export quantized model')
+parser.add_argument('--fallback-num', type=int, default=0, help='amp fallback number')
+parser.add_argument('--audio_in', type=str, default=None, help='["wav", "wav.scp"]')
+parser.add_argument('--calib_num', type=int, default=200, help='calib max num')
+args = parser.parse_args()
+
+model_dir = args.model_name
+if not Path(args.model_name).exists():
+	from modelscope.hub.snapshot_download import snapshot_download
+	try:
+		model_dir = snapshot_download(args.model_name, cache_dir=args.export_dir)
+	except:
+		raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format \
+			(model_dir)
+
+model_file = os.path.join(model_dir, 'model.onnx')
+if args.quantize:
+	model_file = os.path.join(model_dir, 'model_quant.onnx')
+if not os.path.exists(model_file):
+	print(".onnx is not exist, begin to export onnx")
+	from funasr.export.export_model import ModelExport
+	export_model = ModelExport(
+		cache_dir=args.export_dir,
+		onnx=True,
+		device="cpu",
+		quant=args.quantize,
+	)
+	export_model.export(model_dir)

+ 13 - 2
funasr/utils/wav_utils.py

@@ -11,6 +11,7 @@ import librosa
 import numpy as np
 import torch
 import torchaudio
+import soundfile
 import torchaudio.compliance.kaldi as kaldi
 
 
@@ -162,7 +163,13 @@ def compute_fbank(wav_file,
         waveform = torch.from_numpy(waveform.reshape(1, -1))
     else:
         # load pcm from wav, and resample
-        waveform, audio_sr = torchaudio.load(wav_file)
+        try:
+            waveform, audio_sr = torchaudio.load(wav_file)
+        except:
+            waveform, audio_sr = soundfile.read(wav_file, dtype='float32')
+            if waveform.ndim == 2:
+                waveform = waveform[:, 0]
+            waveform = torch.tensor(np.expand_dims(waveform, axis=0))
         waveform = waveform * (1 << 15)
         waveform = torch_resample(waveform, audio_sr, model_sr)
 
@@ -181,7 +188,11 @@ def compute_fbank(wav_file,
 
 
 def wav2num_frame(wav_path, frontend_conf):
-    waveform, sampling_rate = torchaudio.load(wav_path)
+    try:
+        waveform, sampling_rate = torchaudio.load(wav_path)
+    except:
+        waveform, sampling_rate = soundfile.read(wav_path)
+        waveform = torch.tensor(np.expand_dims(waveform, axis=0))
     speech_length = (waveform.shape[1] / sampling_rate) * 1000.
     n_frames = (waveform.shape[1] * 1000.0) / (sampling_rate * frontend_conf["frame_shift"] * frontend_conf["lfr_n"])
     feature_dim = frontend_conf["n_mels"] * frontend_conf["lfr_m"]

+ 1 - 1
funasr/version.txt

@@ -1 +1 @@
-0.6.3
+0.6.5

+ 1 - 1
setup.py

@@ -20,7 +20,7 @@ requirements = {
         "librosa",
         "jamo==0.4.1",  # For kss
         "PyYAML>=5.1.2",
-        "soundfile>=0.10.2",
+        "soundfile>=0.11.0",
         "h5py>=2.10.0",
         "kaldiio>=2.17.0",
         "torch_complex",