speech_asr 2 лет назад
Родитель
Сommit
9f90bad3f5

+ 1 - 0
funasr/bin/train.py

@@ -25,6 +25,7 @@ def get_parser():
         help="The number of gpus. 0 indicates CPU mode",
     )
     parser.add_argument("--seed", type=int, default=0, help="Random seed")
+    parser.add_argument("--task_name", type=str, default="asr", help="Name for different tasks")
 
     # ddp related
     parser.add_argument(

+ 0 - 349
funasr/datasets/iterable_dataset_modelscope.py

@@ -1,349 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-# Part of the implementation is borrowed from espnet/espnet.
-"""Iterable dataset module."""
-import copy
-from io import StringIO
-from pathlib import Path
-from typing import Callable, Collection, Dict, Iterator, Tuple, Union
-
-import kaldiio
-import numpy as np
-import soundfile
-import torch
-from funasr.datasets.dataset import ESPnetDataset
-from torch.utils.data.dataset import IterableDataset
-from typeguard import check_argument_types
-
-from funasr.utils import wav_utils
-
-
-def load_kaldi(input):
-    retval = kaldiio.load_mat(input)
-    if isinstance(retval, tuple):
-        assert len(retval) == 2, len(retval)
-        if isinstance(retval[0], int) and isinstance(retval[1], np.ndarray):
-            # sound scp case
-            rate, array = retval
-        elif isinstance(retval[1], int) and isinstance(retval[0], np.ndarray):
-            # Extended ark format case
-            array, rate = retval
-        else:
-            raise RuntimeError(
-                f'Unexpected type: {type(retval[0])}, {type(retval[1])}')
-
-        # Multichannel wave fie
-        # array: (NSample, Channel) or (Nsample)
-
-    else:
-        # Normal ark case
-        assert isinstance(retval, np.ndarray), type(retval)
-        array = retval
-    return array
-
-
-DATA_TYPES = {
-    'sound':
-    lambda x: soundfile.read(x)[0],
-    'kaldi_ark':
-    load_kaldi,
-    'npy':
-    np.load,
-    'text_int':
-    lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.long, delimiter=' '),
-    'csv_int':
-    lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.long, delimiter=','),
-    'text_float':
-    lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.float32, delimiter=' '
-                         ),
-    'csv_float':
-    lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.float32, delimiter=','
-                         ),
-    'text':
-    lambda x: x,
-}
-
-
-class IterableESPnetDatasetModelScope(IterableDataset):
-    """Pytorch Dataset class for ESPNet.
-
-    Examples:
-        >>> dataset = IterableESPnetDataset([('wav.scp', 'input', 'sound'),
-        ...                                  ('token_int', 'output', 'text_int')],
-        ...                                )
-        >>> for uid, data in dataset:
-        ...     data
-        {'input': per_utt_array, 'output': per_utt_array}
-    """
-    def __init__(self,
-                 path_name_type_list: Collection[Tuple[any, str, str]],
-                 preprocess: Callable[[str, Dict[str, np.ndarray]],
-                                      Dict[str, np.ndarray]] = None,
-                 float_dtype: str = 'float32',
-                 int_dtype: str = 'long',
-                 key_file: str = None,
-                 sample_rate: Union[dict, int] = 16000):
-        assert check_argument_types()
-        if len(path_name_type_list) == 0:
-            raise ValueError(
-                '1 or more elements are required for "path_name_type_list"')
-
-        self.preprocess = preprocess
-
-        self.float_dtype = float_dtype
-        self.int_dtype = int_dtype
-        self.key_file = key_file
-        self.sample_rate = sample_rate
-
-        self.debug_info = {}
-        non_iterable_list = []
-        self.path_name_type_list = []
-
-        path_list = path_name_type_list[0]
-        name = path_name_type_list[1]
-        _type = path_name_type_list[2]
-        if name in self.debug_info:
-            raise RuntimeError(f'"{name}" is duplicated for data-key')
-        self.debug_info[name] = path_list, _type
-        #        for path, name, _type in path_name_type_list:
-        for path in path_list:
-            self.path_name_type_list.append((path, name, _type))
-
-        if len(non_iterable_list) != 0:
-            # Some types doesn't support iterable mode
-            self.non_iterable_dataset = ESPnetDataset(
-                path_name_type_list=non_iterable_list,
-                preprocess=preprocess,
-                float_dtype=float_dtype,
-                int_dtype=int_dtype,
-            )
-        else:
-            self.non_iterable_dataset = None
-
-        self.apply_utt2category = False
-
-    def has_name(self, name) -> bool:
-        return name in self.debug_info
-
-    def names(self) -> Tuple[str, ...]:
-        return tuple(self.debug_info)
-
-    def __repr__(self):
-        _mes = self.__class__.__name__
-        _mes += '('
-        for name, (path, _type) in self.debug_info.items():
-            _mes += f'\n  {name}: {{"path": "{path}", "type": "{_type}"}}'
-        _mes += f'\n  preprocess: {self.preprocess})'
-        return _mes
-
-    def __iter__(
-            self) -> Iterator[Tuple[Union[str, int], Dict[str, np.ndarray]]]:
-        torch.set_printoptions(profile='default')
-        count = len(self.path_name_type_list)
-        for idx in range(count):
-            # 2. Load the entry from each line and create a dict
-            data = {}
-            # 2.a. Load data streamingly
-
-            # value:  /home/fsc/code/MaaS/MaaS-lib-nls-asr/data/test/audios/asr_example.wav
-            value = self.path_name_type_list[idx][0]['file']
-            uid = self.path_name_type_list[idx][0]['key']
-            # name:  speech
-            name = self.path_name_type_list[idx][1]
-            _type = self.path_name_type_list[idx][2]
-            func = DATA_TYPES[_type]
-            array = func(value)
-
-            # 2.b. audio resample
-            if _type == 'sound':
-                audio_sr: int = 16000
-                model_sr: int = 16000
-                if isinstance(self.sample_rate, int):
-                    model_sr = self.sample_rate
-                else:
-                    if 'audio_sr' in self.sample_rate:
-                        audio_sr = self.sample_rate['audio_sr']
-                    if 'model_sr' in self.sample_rate:
-                        model_sr = self.sample_rate['model_sr']
-                array = wav_utils.torch_resample(array, audio_sr, model_sr)
-
-            # array:  [ 1.25122070e-03  ... ]
-            data[name] = array
-
-            # 3. [Option] Apply preprocessing
-            #   e.g. espnet2.train.preprocessor:CommonPreprocessor
-            if self.preprocess is not None:
-                data = self.preprocess(uid, data)
-                # data:  {'speech': array([ 1.25122070e-03 ... 6.10351562e-03])}
-
-            # 4. Force data-precision
-            for name in data:
-                # value is np.ndarray data
-                value = data[name]
-                if not isinstance(value, np.ndarray):
-                    raise RuntimeError(
-                        f'All values must be converted to np.ndarray object '
-                        f'by preprocessing, but "{name}" is still {type(value)}.'
-                    )
-
-                # Cast to desired type
-                if value.dtype.kind == 'f':
-                    value = value.astype(self.float_dtype)
-                elif value.dtype.kind == 'i':
-                    value = value.astype(self.int_dtype)
-                else:
-                    raise NotImplementedError(
-                        f'Not supported dtype: {value.dtype}')
-                data[name] = value
-
-            yield uid, data
-
-        if count == 0:
-            raise RuntimeError('No iteration')
-
-
-class IterableESPnetBytesModelScope(IterableDataset):
-    """Pytorch audio bytes class for ESPNet.
-
-    Examples:
-        >>> dataset = IterableESPnetBytes([('audio bytes', 'input', 'sound'),
-        ...                                ('token_int', 'output', 'text_int')],
-        ...                                )
-        >>> for uid, data in dataset:
-        ...     data
-        {'input': per_utt_array, 'output': per_utt_array}
-    """
-    def __init__(self,
-                 path_name_type_list: Collection[Tuple[any, str, str]],
-                 preprocess: Callable[[str, Dict[str, np.ndarray]],
-                                      Dict[str, np.ndarray]] = None,
-                 float_dtype: str = 'float32',
-                 int_dtype: str = 'long',
-                 key_file: str = None,
-                 sample_rate: Union[dict, int] = 16000):
-        assert check_argument_types()
-        if len(path_name_type_list) == 0:
-            raise ValueError(
-                '1 or more elements are required for "path_name_type_list"')
-
-        self.preprocess = preprocess
-
-        self.float_dtype = float_dtype
-        self.int_dtype = int_dtype
-        self.key_file = key_file
-        self.sample_rate = sample_rate
-
-        self.debug_info = {}
-        non_iterable_list = []
-        self.path_name_type_list = []
-
-        audio_data = path_name_type_list[0]
-        name = path_name_type_list[1]
-        _type = path_name_type_list[2]
-        if name in self.debug_info:
-            raise RuntimeError(f'"{name}" is duplicated for data-key')
-        self.debug_info[name] = audio_data, _type
-        self.path_name_type_list.append((audio_data, name, _type))
-
-        if len(non_iterable_list) != 0:
-            # Some types doesn't support iterable mode
-            self.non_iterable_dataset = ESPnetDataset(
-                path_name_type_list=non_iterable_list,
-                preprocess=preprocess,
-                float_dtype=float_dtype,
-                int_dtype=int_dtype,
-            )
-        else:
-            self.non_iterable_dataset = None
-
-        self.apply_utt2category = False
-
-        if float_dtype == 'float32':
-            self.np_dtype = np.float32
-
-    def has_name(self, name) -> bool:
-        return name in self.debug_info
-
-    def names(self) -> Tuple[str, ...]:
-        return tuple(self.debug_info)
-
-    def __repr__(self):
-        _mes = self.__class__.__name__
-        _mes += '('
-        for name, (path, _type) in self.debug_info.items():
-            _mes += f'\n  {name}: {{"path": "{path}", "type": "{_type}"}}'
-        _mes += f'\n  preprocess: {self.preprocess})'
-        return _mes
-
-    def __iter__(
-            self) -> Iterator[Tuple[Union[str, int], Dict[str, np.ndarray]]]:
-
-        torch.set_printoptions(profile='default')
-        # 2. Load the entry from each line and create a dict
-        data = {}
-        # 2.a. Load data streamingly
-
-        value = self.path_name_type_list[0][0]
-        uid = 'pcm_data'
-        # name:  speech
-        name = self.path_name_type_list[0][1]
-        _type = self.path_name_type_list[0][2]
-        func = DATA_TYPES[_type]
-        # array:  [ 1.25122070e-03  ... ]
-        #        data[name] = np.frombuffer(value, dtype=self.np_dtype)
-
-        # 2.b. byte(PCM16) to float32
-        middle_data = np.frombuffer(value, dtype=np.int16)
-        middle_data = np.asarray(middle_data)
-        if middle_data.dtype.kind not in 'iu':
-            raise TypeError("'middle_data' must be an array of integers")
-        dtype = np.dtype('float32')
-        if dtype.kind != 'f':
-            raise TypeError("'dtype' must be a floating point type")
-
-        i = np.iinfo(middle_data.dtype)
-        abs_max = 2**(i.bits - 1)
-        offset = i.min + abs_max
-        array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max,
-                              dtype=self.np_dtype)
-
-        # 2.c. audio resample
-        if _type == 'sound':
-            audio_sr: int = 16000
-            model_sr: int = 16000
-            if isinstance(self.sample_rate, int):
-                model_sr = self.sample_rate
-            else:
-                if 'audio_sr' in self.sample_rate:
-                    audio_sr = self.sample_rate['audio_sr']
-                if 'model_sr' in self.sample_rate:
-                    model_sr = self.sample_rate['model_sr']
-            array = wav_utils.torch_resample(array, audio_sr, model_sr)
-
-        data[name] = array
-
-        # 3. [Option] Apply preprocessing
-        #   e.g. espnet2.train.preprocessor:CommonPreprocessor
-        if self.preprocess is not None:
-            data = self.preprocess(uid, data)
-            # data:  {'speech': array([ 1.25122070e-03 ... 6.10351562e-03])}
-
-        # 4. Force data-precision
-        for name in data:
-            # value is np.ndarray data
-            value = data[name]
-            if not isinstance(value, np.ndarray):
-                raise RuntimeError(
-                    f'All values must be converted to np.ndarray object '
-                    f'by preprocessing, but "{name}" is still {type(value)}.')
-
-            # Cast to desired type
-            if value.dtype.kind == 'f':
-                value = value.astype(self.float_dtype)
-            elif value.dtype.kind == 'i':
-                value = value.astype(self.int_dtype)
-            else:
-                raise NotImplementedError(
-                    f'Not supported dtype: {value.dtype}')
-            data[name] = value
-
-        yield uid, data

+ 16 - 0
funasr/datasets/small_datasets/build_loader.py

@@ -0,0 +1,16 @@
+import torch
+from funasr.datasets.small_datasets.dataset import ESPnetDataset
+from funasr.datasets.small_datasets.build_preprocess import build_preprocess
+
+def build_dataloader(args):
+    """Build the ``ESPnetDataset`` for the small-dataset pipeline.
+
+    Args:
+        args: Parsed options. ``args.frontend_conf`` (a mapping or ``None``)
+            supplies the target sample rate under the ``"fs"`` key, and
+            ``args.train_dtype`` is forwarded as the dataset float dtype.
+
+    NOTE(review): the function currently builds the dataset but does not
+    return it or wrap it in a loader -- confirm the intended return value.
+    """
+    # Fall back to 16 kHz when no frontend config (or no "fs" entry) is given,
+    # so that ``dest_sample_rate`` is always bound.
+    frontend_conf = args.frontend_conf
+    if frontend_conf is not None and "fs" in frontend_conf:
+        dest_sample_rate = frontend_conf["fs"]
+    else:
+        dest_sample_rate = 16000
+    preprocess_fn = build_preprocess()
+    # NOTE(review): ``iter_options`` is not defined or imported in this module;
+    # it presumably should come from ``args`` -- confirm against the caller.
+    dataset = ESPnetDataset(
+        iter_options.data_path_and_name_and_type,
+        float_dtype=args.train_dtype,
+        preprocess=preprocess_fn,
+        max_cache_size=iter_options.max_cache_size,
+        max_cache_fd=iter_options.max_cache_fd,
+        dest_sample_rate=dest_sample_rate,
+    )

+ 35 - 208
funasr/datasets/small_datasets/dataset.py

@@ -1,15 +1,10 @@
 # Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
 #  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
 
-from abc import ABC
-from abc import abstractmethod
 import collections
 import copy
-import functools
 import logging
 import numbers
-import re
-from typing import Any
 from typing import Callable
 from typing import Collection
 from typing import Dict
@@ -17,7 +12,6 @@ from typing import Mapping
 from typing import Tuple
 from typing import Union
 
-import h5py
 import humanfriendly
 import kaldiio
 import numpy as np
@@ -27,10 +21,6 @@ from typeguard import check_argument_types
 from typeguard import check_return_type
 
 from funasr.fileio.npy_scp import NpyScpReader
-from funasr.fileio.rand_gen_dataset import FloatRandomGenerateDataset
-from funasr.fileio.rand_gen_dataset import IntRandomGenerateDataset
-from funasr.fileio.read_text import load_num_sequence_text
-from funasr.fileio.read_text import read_2column_text
 from funasr.fileio.sound_scp import SoundScpReader
 from funasr.utils.sized_dict import SizedDict
 
@@ -88,25 +78,6 @@ class AdapterForSoundScpReader(collections.abc.Mapping):
         return array
 
 
-class H5FileWrapper:
-    def __init__(self, path: str):
-        self.path = path
-        self.h5_file = h5py.File(path, "r")
-
-    def __repr__(self) -> str:
-        return str(self.h5_file)
-
-    def __len__(self) -> int:
-        return len(self.h5_file)
-
-    def __iter__(self):
-        return iter(self.h5_file)
-
-    def __getitem__(self, key) -> np.ndarray:
-        value = self.h5_file[key]
-        return value[()]
-
-
 def sound_loader(path, dest_sample_rate=16000, float_dtype=None):
     # The file is as follows:
     #   utterance_id_A /some/where/a.wav
@@ -127,156 +98,22 @@ def kaldi_loader(path, float_dtype=None, max_cache_fd: int = 0):
     return AdapterForSoundScpReader(loader, float_dtype)
 
 
-def rand_int_loader(filepath, loader_type):
-    # e.g. rand_int_3_10
-    try:
-        low, high = map(int, loader_type[len("rand_int_") :].split("_"))
-    except ValueError:
-        raise RuntimeError(f"e.g rand_int_3_10: but got {loader_type}")
-    return IntRandomGenerateDataset(filepath, low, high)
-
-
-DATA_TYPES = {
-    "sound": dict(
-        func=sound_loader,
-        kwargs=["dest_sample_rate","float_dtype"],
-        help="Audio format types which supported by sndfile wav, flac, etc."
-        "\n\n"
-        "   utterance_id_a a.wav\n"
-        "   utterance_id_b b.wav\n"
-        "   ...",
-    ),
-    "kaldi_ark": dict(
-        func=kaldi_loader,
-        kwargs=["max_cache_fd"],
-        help="Kaldi-ark file type."
-        "\n\n"
-        "   utterance_id_A /some/where/a.ark:123\n"
-        "   utterance_id_B /some/where/a.ark:456\n"
-        "   ...",
-    ),
-    "npy": dict(
-        func=NpyScpReader,
-        kwargs=[],
-        help="Npy file format."
-        "\n\n"
-        "   utterance_id_A /some/where/a.npy\n"
-        "   utterance_id_B /some/where/b.npy\n"
-        "   ...",
-    ),
-    "text_int": dict(
-        func=functools.partial(load_num_sequence_text, loader_type="text_int"),
-        kwargs=[],
-        help="A text file in which is written a sequence of interger numbers "
-        "separated by space."
-        "\n\n"
-        "   utterance_id_A 12 0 1 3\n"
-        "   utterance_id_B 3 3 1\n"
-        "   ...",
-    ),
-    "csv_int": dict(
-        func=functools.partial(load_num_sequence_text, loader_type="csv_int"),
-        kwargs=[],
-        help="A text file in which is written a sequence of interger numbers "
-        "separated by comma."
-        "\n\n"
-        "   utterance_id_A 100,80\n"
-        "   utterance_id_B 143,80\n"
-        "   ...",
-    ),
-    "text_float": dict(
-        func=functools.partial(load_num_sequence_text, loader_type="text_float"),
-        kwargs=[],
-        help="A text file in which is written a sequence of float numbers "
-        "separated by space."
-        "\n\n"
-        "   utterance_id_A 12. 3.1 3.4 4.4\n"
-        "   utterance_id_B 3. 3.12 1.1\n"
-        "   ...",
-    ),
-    "csv_float": dict(
-        func=functools.partial(load_num_sequence_text, loader_type="csv_float"),
-        kwargs=[],
-        help="A text file in which is written a sequence of float numbers "
-        "separated by comma."
-        "\n\n"
-        "   utterance_id_A 12.,3.1,3.4,4.4\n"
-        "   utterance_id_B 3.,3.12,1.1\n"
-        "   ...",
-    ),
-    "text": dict(
-        func=read_2column_text,
-        kwargs=[],
-        help="Return text as is. The text must be converted to ndarray "
-        "by 'preprocess'."
-        "\n\n"
-        "   utterance_id_A hello world\n"
-        "   utterance_id_B foo bar\n"
-        "   ...",
-    ),
-    "hdf5": dict(
-        func=H5FileWrapper,
-        kwargs=[],
-        help="A HDF5 file which contains arrays at the first level or the second level."
-        "   >>> f = h5py.File('file.h5')\n"
-        "   >>> array1 = f['utterance_id_A']\n"
-        "   >>> array2 = f['utterance_id_B']\n",
-    ),
-    "rand_float": dict(
-        func=FloatRandomGenerateDataset,
-        kwargs=[],
-        help="Generate random float-ndarray which has the given shapes "
-        "in the file."
-        "\n\n"
-        "   utterance_id_A 3,4\n"
-        "   utterance_id_B 10,4\n"
-        "   ...",
-    ),
-    "rand_int_\\d+_\\d+": dict(
-        func=rand_int_loader,
-        kwargs=["loader_type"],
-        help="e.g. 'rand_int_0_10'. Generate random int-ndarray which has the given "
-        "shapes in the path. "
-        "Give the lower and upper value by the file type. e.g. "
-        "rand_int_0_10 -> Generate integers from 0 to 10."
-        "\n\n"
-        "   utterance_id_A 3,4\n"
-        "   utterance_id_B 10,4\n"
-        "   ...",
-    ),
-}
-
-
-class AbsDataset(Dataset, ABC):
-    @abstractmethod
-    def has_name(self, name) -> bool:
-        raise NotImplementedError
-
-    @abstractmethod
-    def names(self) -> Tuple[str, ...]:
-        raise NotImplementedError
-
-    @abstractmethod
-    def __getitem__(self, uid) -> Tuple[Any, Dict[str, np.ndarray]]:
-        raise NotImplementedError
-
-
-class ESPnetDataset(AbsDataset):
+class ESPnetDataset(Dataset):
     """
-        Pytorch Dataset class for FunASR, simplied from ESPnet
+        Pytorch Dataset class for FunASR, modified from ESPnet
     """
 
     def __init__(
-        self,
-        path_name_type_list: Collection[Tuple[str, str, str]],
-        preprocess: Callable[
-            [str, Dict[str, np.ndarray]], Dict[str, np.ndarray]
-        ] = None,
-        float_dtype: str = "float32",
-        int_dtype: str = "long",
-        max_cache_size: Union[float, int, str] = 0.0,
-        max_cache_fd: int = 0,
-        dest_sample_rate: int = 16000,
+            self,
+            path_name_type_list: Collection[Tuple[str, str, str]],
+            preprocess: Callable[
+                [str, Dict[str, np.ndarray]], Dict[str, np.ndarray]
+            ] = None,
+            float_dtype: str = "float32",
+            int_dtype: str = "long",
+            max_cache_size: Union[float, int, str] = 0.0,
+            max_cache_fd: int = 0,
+            dest_sample_rate: int = 16000,
     ):
         assert check_argument_types()
         if len(path_name_type_list) == 0:
@@ -304,8 +141,6 @@ class ESPnetDataset(AbsDataset):
             if len(self.loader_dict[name]) == 0:
                 raise RuntimeError(f"{path} has no samples")
 
-            # TODO(kamo): Should check consistency of each utt-keys?
-
         if isinstance(max_cache_size, str):
             max_cache_size = humanfriendly.parse_size(max_cache_size)
         self.max_cache_size = max_cache_size
@@ -315,43 +150,35 @@ class ESPnetDataset(AbsDataset):
             self.cache = None
 
     def _build_loader(
-        self, path: str, loader_type: str
+            self, path: str, loader_type: str
     ) -> Mapping[str, Union[np.ndarray, torch.Tensor, str, numbers.Number]]:
         """Helper function to instantiate Loader.
 
         Args:
             path:  The file path
-            loader_type:  loader_type. sound, npy, text_int, text_float, etc
+            loader_type:  loader_type. One of sound, kaldi_ark, npy or text
         """
-        for key, dic in DATA_TYPES.items():
-            # e.g. loader_type="sound"
-            # -> return DATA_TYPES["sound"]["func"](path)
-            if re.match(key, loader_type):
-                kwargs = {}
-                for key2 in dic["kwargs"]:
-                    if key2 == "loader_type":
-                        kwargs["loader_type"] = loader_type
-                    elif key2 == "dest_sample_rate" and loader_type=="sound":
-                        kwargs["dest_sample_rate"] = self.dest_sample_rate
-                    elif key2 == "float_dtype":
-                        kwargs["float_dtype"] = self.float_dtype
-                    elif key2 == "int_dtype":
-                        kwargs["int_dtype"] = self.int_dtype
-                    elif key2 == "max_cache_fd":
-                        kwargs["max_cache_fd"] = self.max_cache_fd
-                    else:
-                        raise RuntimeError(f"Not implemented keyword argument: {key2}")
-
-                func = dic["func"]
-                try:
-                    return func(path, **kwargs)
-                except Exception:
-                    if hasattr(func, "__name__"):
-                        name = func.__name__
+        if loader_type == "sound":
+            loader = SoundScpReader(path, self.dest_sample_rate, normalize=True, always_2d=False)
+            return AdapterForSoundScpReader(loader, self.float_dtype)
+        elif loader_type == "kaldi_ark":
+            loader = kaldiio.load_scp(path, max_cache_fd=self.max_cache_fd)
+            return AdapterForSoundScpReader(loader, self.float_dtype)
+        elif loader_type == "npy":
+            # The reader is built from the scp path, exactly as the previous
+            # DATA_TYPES dispatch did via ``func(path)``.
+            return NpyScpReader(path)
+        elif loader_type == "text":
+            text_loader = {}
+            with open(path, "r", encoding="utf-8") as f:
+                for linenum, line in enumerate(f, 1):
+                    sps = line.rstrip().split(maxsplit=1)
+                    if len(sps) == 1:
+                        k, v = sps[0], ""
                     else:
-                        name = str(func)
-                    logging.error(f"An error happened with {name}({path})")
-                    raise
+                        k, v = sps
+                    if k in text_loader:
+                        raise RuntimeError(f"{k} is duplicated ({path}:{linenum})")
+                    text_loader[k] = v
+            return text_loader
         else:
             raise RuntimeError(f"Not supported: loader_type={loader_type}")
 
@@ -392,7 +219,7 @@ class ESPnetDataset(AbsDataset):
                 if isinstance(value, (list, tuple)):
                     value = np.array(value)
                 if not isinstance(
-                    value, (np.ndarray, torch.Tensor, str, numbers.Number)
+                        value, (np.ndarray, torch.Tensor, str, numbers.Number)
                 ):
                     raise TypeError(
                         f"Must be ndarray, torch.Tensor, str or Number: {type(value)}"

+ 826 - 0
funasr/datasets/small_datasets/preprocessor.py

@@ -0,0 +1,826 @@
+from abc import ABC
+from abc import abstractmethod
+from pathlib import Path
+from typing import Collection
+from typing import Dict
+from typing import Iterable
+from typing import List
+from typing import Union
+
+import numpy as np
+import scipy.signal
+import soundfile
+from typeguard import check_argument_types
+from typeguard import check_return_type
+
+from funasr.text.build_tokenizer import build_tokenizer
+from funasr.text.cleaner import TextCleaner
+from funasr.text.token_id_converter import TokenIDConverter
+
+
+class AbsPreprocessor(ABC):
+    """Abstract base class for per-utterance preprocessors.
+
+    Subclasses must implement ``__call__``; the ``train`` flag given at
+    construction time is stored on the instance as ``self.train``.
+    """
+
+    def __init__(self, train: bool):
+        """Store the ``train`` flag on the instance."""
+        self.train = train
+
+    @abstractmethod
+    def __call__(
+            self, uid: str, data: Dict[str, Union[str, np.ndarray]]
+    ) -> Dict[str, np.ndarray]:
+        """Preprocess the data of one utterance.
+
+        Args:
+            uid: Utterance identifier.
+            data: Mapping from data name to a string or an ndarray.
+
+        Returns:
+            Mapping from data name to ndarray.
+
+        Raises:
+            NotImplementedError: Always, in this abstract base implementation.
+        """
+        raise NotImplementedError
+
+
+def forward_segment(text, dic):
+    """Segment ``text`` left to right by greedy longest dictionary match.
+
+    At each position the longest slice ``text[pos:end]`` that is a member of
+    ``dic`` is emitted. If no slice longer than the single element
+    ``text[pos]`` is found in ``dic``, that element is emitted on its own.
+
+    Args:
+        text: Sequence to segment; it is indexed and sliced (presumably a
+            string).
+        dic: Container supporting the ``in`` operator for candidate slices.
+
+    Returns:
+        List of segments in order of appearance.
+    """
+    segments = []
+    pos = 0
+    total = len(text)
+    while pos < total:
+        # Default to the single element when nothing longer matches.
+        best = text[pos]
+        for end in range(pos + 1, total + 1):
+            candidate = text[pos:end]
+            if candidate in dic and len(candidate) > len(best):
+                best = candidate
+        segments.append(best)
+        pos += len(best)
+    return segments
+
+
+def seg_tokenize(txt, seg_dict):
+    """Translate each element of ``txt`` through ``seg_dict`` into tokens.
+
+    Elements found in ``seg_dict`` are replaced by their mapped string, all
+    others by ``"<unk>"``. The replacements are joined with spaces and the
+    result is split on whitespace, so a mapped string containing several
+    whitespace-separated pieces yields several tokens.
+
+    Args:
+        txt: Iterable of words to look up.
+        seg_dict: Mapping from word to a string of space-separated pieces.
+
+    Returns:
+        Flat list of token strings.
+    """
+    pieces = [seg_dict[word] if word in seg_dict else "<unk>" for word in txt]
+    return " ".join(pieces).split()
+
+
+def seg_tokenize_wo_pattern(txt, seg_dict):
+    """Look up every element of ``txt`` in ``seg_dict`` and return the tokens.
+
+    Each element present in ``seg_dict`` contributes its mapped string and
+    each missing element contributes ``"<unk>"``. The contributions are
+    joined with single spaces and then split on whitespace.
+
+    Args:
+        txt: Iterable of words to look up.
+        seg_dict: Mapping from word to a string of space-separated pieces.
+
+    Returns:
+        Flat list of token strings.
+    """
+    mapped = (seg_dict[item] if item in seg_dict else "<unk>" for item in txt)
+    return " ".join(mapped).split()
+
+
+def framing(
+        x,
+        frame_length: int = 512,
+        frame_shift: int = 256,
+        centered: bool = True,
+        padded: bool = True,
+):
+    """Slice the last axis of ``x`` into (possibly overlapping) frames.
+
+    Args:
+        x: Input ndarray; framing is applied along its last axis.
+        frame_length: Number of samples per frame.
+        frame_shift: Hop between the starts of consecutive frames.
+        centered: If True, zero-pad ``frame_length // 2`` samples on both
+            sides of the last axis before framing.
+        padded: If True, zero-pad the end of the last axis so that the
+            length becomes ``frame_length + k * frame_shift`` for an integer
+            ``k``.
+
+    Returns:
+        Array of shape ``x.shape[:-1] + (n_frames, frame_length)``. Unless
+        ``frame_length == frame_shift == 1`` it is a strided view, so frames
+        share memory whenever ``frame_shift < frame_length``.
+
+    Raises:
+        ValueError: If ``x`` is empty, ``frame_length`` is not positive or
+            exceeds the input length, or ``frame_shift`` is not positive.
+    """
+    if x.size == 0:
+        raise ValueError("Input array size is zero")
+    if frame_length < 1:
+        raise ValueError("frame_length must be a positive integer")
+    if frame_length > x.shape[-1]:
+        raise ValueError("frame_length is greater than input length")
+    if frame_shift <= 0:
+        raise ValueError("frame_shift must be greater than 0")
+
+    # Only the last axis is ever padded; every leading axis gets (0, 0).
+    leading = [(0, 0)] * (x.ndim - 1)
+
+    if centered:
+        half = frame_length // 2
+        x = np.pad(x, leading + [(half, half)], mode="constant", constant_values=0)
+
+    if padded:
+        # Extend to a whole number of hops past the first frame.
+        tail = (-(x.shape[-1] - frame_length) % frame_shift) % frame_length
+        x = np.pad(x, leading + [(0, tail)], mode="constant", constant_values=0)
+
+    # Degenerate case: every sample is its own frame.
+    if frame_length == 1 and frame_shift == 1:
+        return x[..., None]
+
+    n_frames = (x.shape[-1] - frame_length) // frame_shift + 1
+    step = x.strides[-1]
+    return np.lib.stride_tricks.as_strided(
+        x,
+        shape=x.shape[:-1] + (n_frames, frame_length),
+        strides=x.strides[:-1] + (frame_shift * step, step),
+    )
+
+
+def detect_non_silence(
+        x: np.ndarray,
+        threshold: float = 0.01,
+        frame_length: int = 1024,
+        frame_shift: int = 512,
+        window: str = "boxcar",
+) -> np.ndarray:
+    """Power based voice activity detection.
+
+    A frame is marked as non-silent when its mean power, relative to the mean
+    frame power of the signal, exceeds ``threshold``. The per-frame decision
+    is expanded back to sample resolution.
+
+    Args:
+        x: (Channel, Time)
+        threshold: Minimum ratio of frame power to mean power for a frame to
+            count as non-silent.
+        frame_length: Analysis frame length in samples.
+        frame_shift: Hop between consecutive frames in samples.
+        window: Window name passed to ``scipy.signal.get_window``.
+
+    Returns:
+        Boolean array with the same shape as ``x``. It is all True when the
+        input is shorter than one frame or when the mean power is zero.
+
+    >>> x = np.random.randn(1000)
+    >>> detect = detect_non_silence(x)
+    >>> assert x.shape == detect.shape
+    >>> assert detect.dtype == bool
+    """
+    # The builtin ``bool`` is used because the ``np.bool`` alias no longer
+    # exists in NumPy 1.24-1.26.
+    if x.shape[-1] < frame_length:
+        return np.full(x.shape, fill_value=True, dtype=bool)
+
+    if x.dtype.kind == "i":
+        x = x.astype(np.float64)
+    # framed_w: (C, T, F)
+    framed_w = framing(
+        x,
+        frame_length=frame_length,
+        frame_shift=frame_shift,
+        centered=False,
+        padded=True,
+    )
+    win = scipy.signal.get_window(window, frame_length).astype(framed_w.dtype)
+    # framing() returns a strided view whose frames share samples when
+    # frame_shift < frame_length. Multiply out of place so that a shared
+    # sample is not scaled once per frame it belongs to.
+    framed_w = framed_w * win
+    # power: (C, T)
+    power = (framed_w ** 2).mean(axis=-1)
+    # mean_power: (C, 1)
+    mean_power = np.mean(power, axis=-1, keepdims=True)
+    if np.all(mean_power == 0):
+        return np.full(x.shape, fill_value=True, dtype=bool)
+    # detect_frames: (C, T)
+    detect_frames = power / mean_power > threshold
+    # detects: (C, T, F)
+    detects = np.broadcast_to(
+        detect_frames[..., None], detect_frames.shape + (frame_shift,)
+    )
+    # detects: (C, TF)
+    detects = detects.reshape(*detect_frames.shape[:-1], -1)
+    # detects: (C, TF), edge-padded up to the original number of samples
+    return np.pad(
+        detects,
+        [(0, 0)] * (x.ndim - 1) + [(0, x.shape[-1] - detects.shape[-1])],
+        mode="edge",
+    )
+
+
+class CommonPreprocessor(AbsPreprocessor):
+    """Preprocessor that augments speech and converts text to token ids.
+
+    Speech handling (``_speech_process``): in training mode, when an RIR list
+    and/or a noise list was loaded, the waveform may be convolved with a
+    randomly chosen room impulse response and mixed with a randomly chosen
+    noise clip. Independently of the mode, ``speech_volume_normalize``
+    rescales the waveform so that its peak absolute value equals that number.
+
+    Text handling (``_text_process``): when ``token_type`` is given, the text
+    stored under ``text_name`` is cleaned, tokenized and replaced by an int64
+    array of token ids.
+
+    Args:
+        train: Training-mode flag; RIR/noise lists are only loaded when True.
+        token_type: Tokenizer type forwarded to ``build_tokenizer``. ``None``
+            disables text processing.
+        token_list: Forwarded to ``TokenIDConverter``; required when
+            ``token_type`` is not ``None``.
+        bpemodel: Forwarded to ``build_tokenizer``.
+        text_cleaner: Forwarded to ``TextCleaner``.
+        g2p_type: Forwarded to ``build_tokenizer``.
+        unk_symbol: Forwarded to ``TokenIDConverter``.
+        space_symbol: Forwarded to ``build_tokenizer``.
+        non_linguistic_symbols: Forwarded to ``build_tokenizer``.
+        delimiter: Forwarded to ``build_tokenizer``.
+        rir_scp: Text file listing RIR audio paths, one per line, written
+            either as ``path`` or as ``key path``.
+        rir_apply_prob: Probability of applying an RIR to a sample.
+        noise_scp: Text file listing noise audio paths (same line format as
+            ``rir_scp``).
+        noise_apply_prob: Probability of adding noise to a sample.
+        noise_db_range: ``"low_high"`` (e.g. ``"3_10"``) or a single value
+            (e.g. ``"5"``) giving the range the speech-to-noise ratio in dB
+            is drawn from.
+        speech_volume_normalize: Target peak absolute amplitude, or ``None``
+            to leave the volume untouched.
+        speech_name: Key of the speech entry in ``data``.
+        text_name: Key of the text entry in ``data``.
+        split_with_space: Split the cleaned text on single spaces instead of
+            calling the tokenizer.
+        seg_dict_file: Text file with ``key value ...`` lines; loaded into
+            ``self.seg_dict`` and used only when ``split_with_space`` is True.
+
+    Raises:
+        ValueError: If ``token_type`` is given without ``token_list``, or if
+            ``noise_db_range`` has more than two ``_``-separated fields.
+    """
+
+    def __init__(
+            self,
+            train: bool,
+            token_type: str = None,
+            token_list: Union[Path, str, Iterable[str]] = None,
+            bpemodel: Union[Path, str, Iterable[str]] = None,
+            text_cleaner: Collection[str] = None,
+            g2p_type: str = None,
+            unk_symbol: str = "<unk>",
+            space_symbol: str = "<space>",
+            non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
+            delimiter: str = None,
+            rir_scp: str = None,
+            rir_apply_prob: float = 1.0,
+            noise_scp: str = None,
+            noise_apply_prob: float = 1.0,
+            noise_db_range: str = "3_10",
+            speech_volume_normalize: float = None,
+            speech_name: str = "speech",
+            text_name: str = "text",
+            split_with_space: bool = False,
+            seg_dict_file: str = None,
+    ):
+        super().__init__(train)
+        self.train = train
+        self.speech_name = speech_name
+        self.text_name = text_name
+        self.speech_volume_normalize = speech_volume_normalize
+        self.rir_apply_prob = rir_apply_prob
+        self.noise_apply_prob = noise_apply_prob
+        self.split_with_space = split_with_space
+        self.seg_dict = None
+        if seg_dict_file is not None:
+            self.seg_dict = {}
+            # Read as UTF-8, like the scp files below, so that the result does
+            # not depend on the platform's default encoding.
+            with open(seg_dict_file, encoding="utf-8") as f:
+                for line in f:
+                    s = line.strip().split()
+                    if not s:
+                        # A blank line carries no key; skip it rather than
+                        # failing with IndexError on s[0].
+                        continue
+                    self.seg_dict[s[0]] = " ".join(s[1:])
+
+        if token_type is not None:
+            if token_list is None:
+                raise ValueError("token_list is required if token_type is not None")
+            self.text_cleaner = TextCleaner(text_cleaner)
+
+            self.tokenizer = build_tokenizer(
+                token_type=token_type,
+                bpemodel=bpemodel,
+                delimiter=delimiter,
+                space_symbol=space_symbol,
+                non_linguistic_symbols=non_linguistic_symbols,
+                g2p_type=g2p_type,
+            )
+            self.token_id_converter = TokenIDConverter(
+                token_list=token_list,
+                unk_symbol=unk_symbol,
+            )
+        else:
+            self.text_cleaner = None
+            self.tokenizer = None
+            self.token_id_converter = None
+
+        if train and rir_scp is not None:
+            self.rirs = []
+            with open(rir_scp, "r", encoding="utf-8") as f:
+                for line in f:
+                    # Accept both "path" and "key path" lines.
+                    sps = line.strip().split(None, 1)
+                    if len(sps) == 1:
+                        self.rirs.append(sps[0])
+                    else:
+                        self.rirs.append(sps[1])
+        else:
+            self.rirs = None
+
+        if train and noise_scp is not None:
+            self.noises = []
+            with open(noise_scp, "r", encoding="utf-8") as f:
+                for line in f:
+                    # Accept both "path" and "key path" lines.
+                    sps = line.strip().split(None, 1)
+                    if len(sps) == 1:
+                        self.noises.append(sps[0])
+                    else:
+                        self.noises.append(sps[1])
+            sps = noise_db_range.split("_")
+            if len(sps) == 1:
+                # A single value means a fixed level: low == high.
+                # (Tuple-unpacking one float here would raise TypeError.)
+                self.noise_db_low = self.noise_db_high = float(sps[0])
+            elif len(sps) == 2:
+                self.noise_db_low, self.noise_db_high = float(sps[0]), float(sps[1])
+            else:
+                raise ValueError(
+                    f"Format error: '{noise_db_range}' e.g. -3_4 -> [-3db,4db]"
+                )
+        else:
+            self.noises = None
+
+    def _speech_process(
+            self, data: Dict[str, Union[str, np.ndarray]]
+    ) -> Dict[str, Union[str, np.ndarray]]:
+        """Apply RIR/noise augmentation and volume normalization to the speech.
+
+        The entry ``data[self.speech_name]`` is replaced in place; ``data`` is
+        returned unchanged when it has no speech entry.
+        """
+        assert check_argument_types()
+        if self.speech_name in data:
+            if self.train and (self.rirs is not None or self.noises is not None):
+                speech = data[self.speech_name]
+                nsamples = len(speech)
+
+                # speech: (Nmic, Time)
+                if speech.ndim == 1:
+                    speech = speech[None, :]
+                else:
+                    speech = speech.T
+                # Calc power on non silence region
+                power = (speech[detect_non_silence(speech)] ** 2).mean()
+
+                # 1. Convolve RIR
+                if self.rirs is not None and self.rir_apply_prob >= np.random.random():
+                    rir_path = np.random.choice(self.rirs)
+                    if rir_path is not None:
+                        rir, _ = soundfile.read(
+                            rir_path, dtype=np.float64, always_2d=True
+                        )
+
+                        # rir: (Nmic, Time)
+                        rir = rir.T
+
+                        # speech: (Nmic, Time)
+                        # Note that this operation doesn't change the signal length
+                        speech = scipy.signal.convolve(speech, rir, mode="full")[
+                                 :, : speech.shape[1]
+                                 ]
+                        # Reverse mean power to the original power
+                        power2 = (speech[detect_non_silence(speech)] ** 2).mean()
+                        speech = np.sqrt(power / max(power2, 1e-10)) * speech
+
+                # 2. Add Noise
+                if (
+                        self.noises is not None
+                        and self.noise_apply_prob >= np.random.random()
+                ):
+                    noise_path = np.random.choice(self.noises)
+                    if noise_path is not None:
+                        noise_db = np.random.uniform(
+                            self.noise_db_low, self.noise_db_high
+                        )
+                        with soundfile.SoundFile(noise_path) as f:
+                            if f.frames == nsamples:
+                                noise = f.read(dtype=np.float64, always_2d=True)
+                            elif f.frames < nsamples:
+                                # Noise is shorter than speech: place it at a
+                                # random offset and fill the rest by wrapping.
+                                offset = np.random.randint(0, nsamples - f.frames)
+                                # noise: (Time, Nmic)
+                                noise = f.read(dtype=np.float64, always_2d=True)
+                                # Repeat noise
+                                noise = np.pad(
+                                    noise,
+                                    [(offset, nsamples - f.frames - offset), (0, 0)],
+                                    mode="wrap",
+                                )
+                            else:
+                                # Noise is longer than speech: cut a random
+                                # window of exactly nsamples frames.
+                                offset = np.random.randint(0, f.frames - nsamples)
+                                f.seek(offset)
+                                # noise: (Time, Nmic)
+                                noise = f.read(
+                                    nsamples, dtype=np.float64, always_2d=True
+                                )
+                                if len(noise) != nsamples:
+                                    raise RuntimeError(f"Something wrong: {noise_path}")
+                        # noise: (Nmic, Time)
+                        noise = noise.T
+
+                        # Scale the noise so that speech power / noise power
+                        # corresponds to noise_db decibels.
+                        noise_power = (noise ** 2).mean()
+                        scale = (
+                                10 ** (-noise_db / 20)
+                                * np.sqrt(power)
+                                / np.sqrt(max(noise_power, 1e-10))
+                        )
+                        speech = speech + scale * noise
+
+                speech = speech.T
+                # Rescale only when the augmented signal exceeds [-1, 1].
+                ma = np.max(np.abs(speech))
+                if ma > 1.0:
+                    speech /= ma
+                data[self.speech_name] = speech
+
+            if self.speech_volume_normalize is not None:
+                speech = data[self.speech_name]
+                ma = np.max(np.abs(speech))
+                data[self.speech_name] = speech * self.speech_volume_normalize / ma
+        assert check_return_type(data)
+        return data
+
+    def _text_process(
+            self, data: Dict[str, Union[str, np.ndarray]]
+    ) -> Dict[str, np.ndarray]:
+        """Replace ``data[self.text_name]`` by an int64 array of token ids.
+
+        Nothing is changed when the text entry is missing or no tokenizer was
+        configured.
+        """
+        if self.text_name in data and self.tokenizer is not None:
+            text = data[self.text_name]
+            text = self.text_cleaner(text)
+            if self.split_with_space:
+                tokens = text.strip().split(" ")
+                if self.seg_dict is not None:
+                    # Re-segment the space-free text with the seg dict before
+                    # mapping the pieces through it.
+                    tokens = forward_segment("".join(tokens), self.seg_dict)
+                    tokens = seg_tokenize(tokens, self.seg_dict)
+            else:
+                tokens = self.tokenizer.text2tokens(text)
+            text_ints = self.token_id_converter.tokens2ids(tokens)
+            data[self.text_name] = np.array(text_ints, dtype=np.int64)
+        assert check_return_type(data)
+        return data
+
+    def __call__(
+            self, uid: str, data: Dict[str, Union[str, np.ndarray]]
+    ) -> Dict[str, np.ndarray]:
+        """Run speech processing followed by text processing on one sample."""
+        assert check_argument_types()
+
+        data = self._speech_process(data)
+        data = self._text_process(data)
+        return data
+
+
+## FIXME
+class LMPreprocessor(CommonPreprocessor):
+    """``CommonPreprocessor`` variant with its own text tokenization step.
+
+    The constructor accepts the same arguments as ``CommonPreprocessor`` and
+    forwards all of them unchanged. Only ``_text_process`` differs: with
+    ``split_with_space`` and a seg dict, the space-separated tokens are passed
+    straight to ``seg_tokenize_wo_pattern``.
+    """
+
+    def __init__(
+            self,
+            train: bool,
+            token_type: str = None,
+            token_list: Union[Path, str, Iterable[str]] = None,
+            bpemodel: Union[Path, str, Iterable[str]] = None,
+            text_cleaner: Collection[str] = None,
+            g2p_type: str = None,
+            unk_symbol: str = "<unk>",
+            space_symbol: str = "<space>",
+            non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
+            delimiter: str = None,
+            rir_scp: str = None,
+            rir_apply_prob: float = 1.0,
+            noise_scp: str = None,
+            noise_apply_prob: float = 1.0,
+            noise_db_range: str = "3_10",
+            speech_volume_normalize: float = None,
+            speech_name: str = "speech",
+            text_name: str = "text",
+            split_with_space: bool = False,
+            seg_dict_file: str = None,
+    ):
+        # Every argument is handed to the parent under its own name.
+        super().__init__(
+            train=train,
+            token_type=token_type,
+            token_list=token_list,
+            bpemodel=bpemodel,
+            text_cleaner=text_cleaner,
+            g2p_type=g2p_type,
+            unk_symbol=unk_symbol,
+            space_symbol=space_symbol,
+            non_linguistic_symbols=non_linguistic_symbols,
+            delimiter=delimiter,
+            rir_scp=rir_scp,
+            rir_apply_prob=rir_apply_prob,
+            noise_scp=noise_scp,
+            noise_apply_prob=noise_apply_prob,
+            noise_db_range=noise_db_range,
+            speech_volume_normalize=speech_volume_normalize,
+            speech_name=speech_name,
+            text_name=text_name,
+            split_with_space=split_with_space,
+            seg_dict_file=seg_dict_file,
+        )
+
+    def _text_process(
+            self, data: Dict[str, Union[str, np.ndarray]]
+    ) -> Dict[str, np.ndarray]:
+        """Replace ``data[self.text_name]`` by an int64 array of token ids."""
+        has_text = self.text_name in data
+        if has_text and self.tokenizer is not None:
+            cleaned = self.text_cleaner(data[self.text_name])
+            if not self.split_with_space:
+                tokens = self.tokenizer.text2tokens(cleaned)
+            else:
+                tokens = cleaned.strip().split(" ")
+                if self.seg_dict is not None:
+                    tokens = seg_tokenize_wo_pattern(tokens, self.seg_dict)
+            ids = self.token_id_converter.tokens2ids(tokens)
+            data[self.text_name] = np.array(ids, dtype=np.int64)
+        assert check_return_type(data)
+        return data
+
+
+class CommonPreprocessor_multi(AbsPreprocessor):
+    """Text-only preprocessor that tokenizes several text fields at once.
+
+    One tokenizer / id-converter pair is shared by every field listed in
+    ``text_name``. The speech entry is currently passed through untouched.
+    """
+
+    def __init__(
+            self,
+            train: bool,
+            token_type: str = None,
+            token_list: Union[Path, str, Iterable[str]] = None,
+            bpemodel: Union[Path, str, Iterable[str]] = None,
+            text_cleaner: Collection[str] = None,
+            g2p_type: str = None,
+            unk_symbol: str = "<unk>",
+            space_symbol: str = "<space>",
+            non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
+            delimiter: str = None,
+            speech_name: str = "speech",
+            text_name: List[str] = ["text"],
+    ):
+        super().__init__(train)
+        self.train = train
+        self.speech_name = speech_name
+        self.text_name = text_name
+
+        if token_type is None:
+            # No tokenizer requested: text fields will be left as they are.
+            self.text_cleaner = None
+            self.tokenizer = None
+            self.token_id_converter = None
+        else:
+            if token_list is None:
+                raise ValueError("token_list is required if token_type is not None")
+            self.text_cleaner = TextCleaner(text_cleaner)
+
+            self.tokenizer = build_tokenizer(
+                token_type=token_type,
+                bpemodel=bpemodel,
+                delimiter=delimiter,
+                space_symbol=space_symbol,
+                non_linguistic_symbols=non_linguistic_symbols,
+                g2p_type=g2p_type,
+            )
+            self.token_id_converter = TokenIDConverter(
+                token_list=token_list,
+                unk_symbol=unk_symbol,
+            )
+
+    def _text_process(
+            self, data: Dict[str, Union[str, np.ndarray]]
+    ) -> Dict[str, np.ndarray]:
+        """Replace every configured text field in ``data`` by int64 token ids."""
+        for field in self.text_name:
+            if not (field in data and self.tokenizer is not None):
+                continue
+            cleaned = self.text_cleaner(data[field])
+            pieces = self.tokenizer.text2tokens(cleaned)
+            ids = self.token_id_converter.tokens2ids(pieces)
+            data[field] = np.array(ids, dtype=np.int64)
+        assert check_return_type(data)
+        return data
+
+    def __call__(
+            self, uid: str, data: Dict[str, Union[str, np.ndarray]]
+    ) -> Dict[str, np.ndarray]:
+        """Tokenize the text fields of one sample and return ``data``."""
+        assert check_argument_types()
+
+        if self.speech_name in data:
+            # Placeholder: no speech processing yet. Candidates are STFT,
+            # Fbank, CMVN and data augmentation.
+            pass
+
+        return self._text_process(data)
+
+
+class MutliTokenizerCommonPreprocessor(CommonPreprocessor):
+    """``CommonPreprocessor`` with one tokenizer per text field.
+
+    ``token_type``, ``token_list``, ``bpemodel`` and ``text_name`` are
+    parallel lists: entry ``i`` of each describes how the text field
+    ``text_name[i]`` is tokenized. An entry whose ``token_type`` is ``None``
+    leaves that field untouched.
+    """
+
+    def __init__(
+            self,
+            train: bool,
+            token_type: List[str] = [None],
+            token_list: List[Union[Path, str, Iterable[str]]] = [None],
+            bpemodel: List[Union[Path, str, Iterable[str]]] = [None],
+            text_cleaner: Collection[str] = None,
+            g2p_type: str = None,
+            unk_symbol: str = "<unk>",
+            space_symbol: str = "<space>",
+            non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
+            delimiter: str = None,
+            rir_scp: str = None,
+            rir_apply_prob: float = 1.0,
+            noise_scp: str = None,
+            noise_apply_prob: float = 1.0,
+            noise_db_range: str = "3_10",
+            speech_volume_normalize: float = None,
+            speech_name: str = "speech",
+            text_name: List[str] = ["text"],
+    ):
+        # TODO(jiatong): sync with Kamo and Jing on interface for preprocessor
+        # The parent is initialised from the first entry of each list; the
+        # tokenizer-related attributes are replaced by per-field lists below.
+        super().__init__(
+            train=train,
+            token_type=token_type[0],
+            token_list=token_list[0],
+            bpemodel=bpemodel[0],
+            text_cleaner=text_cleaner,
+            g2p_type=g2p_type,
+            unk_symbol=unk_symbol,
+            space_symbol=space_symbol,
+            non_linguistic_symbols=non_linguistic_symbols,
+            delimiter=delimiter,
+            speech_name=speech_name,
+            text_name=text_name[0],
+            rir_scp=rir_scp,
+            rir_apply_prob=rir_apply_prob,
+            noise_scp=noise_scp,
+            noise_apply_prob=noise_apply_prob,
+            noise_db_range=noise_db_range,
+            speech_volume_normalize=speech_volume_normalize,
+        )
+
+        assert (
+                len(token_type) == len(token_list) == len(bpemodel) == len(text_name)
+        ), "token_type, token_list, bpemodel, or processing text_name mismatched"
+        self.num_tokenizer = len(token_type)
+        self.tokenizer = []
+        self.token_id_converter = []
+
+        for idx in range(self.num_tokenizer):
+            if token_type[idx] is None:
+                # Keep the lists aligned with text_name even for fields that
+                # are not tokenized.
+                self.tokenizer.append(None)
+                self.token_id_converter.append(None)
+                continue
+            if token_list[idx] is None:
+                raise ValueError("token_list is required if token_type is not None")
+
+            field_tokenizer = build_tokenizer(
+                token_type=token_type[idx],
+                bpemodel=bpemodel[idx],
+                delimiter=delimiter,
+                space_symbol=space_symbol,
+                non_linguistic_symbols=non_linguistic_symbols,
+                g2p_type=g2p_type,
+            )
+            self.tokenizer.append(field_tokenizer)
+            field_converter = TokenIDConverter(
+                token_list=token_list[idx],
+                unk_symbol=unk_symbol,
+            )
+            self.token_id_converter.append(field_converter)
+
+        self.text_cleaner = TextCleaner(text_cleaner)
+        self.text_name = text_name  # override the text_name from CommonPreprocessor
+
+    def _text_process(
+            self, data: Dict[str, Union[str, np.ndarray]]
+    ) -> Dict[str, np.ndarray]:
+        """Tokenize each configured text field with its own tokenizer."""
+        for idx in range(self.num_tokenizer):
+            field = self.text_name[idx]
+            if not (field in data and self.tokenizer[idx] is not None):
+                continue
+            cleaned = self.text_cleaner(data[field])
+            pieces = self.tokenizer[idx].text2tokens(cleaned)
+            ids = self.token_id_converter[idx].tokens2ids(pieces)
+            data[field] = np.array(ids, dtype=np.int64)
+        assert check_return_type(data)
+        return data
+
+
+class CodeMixTokenizerCommonPreprocessor(CommonPreprocessor):
+    """``CommonPreprocessor`` for text mixing single-byte and multi-byte chars.
+
+    Before the regular processing the text is split into words: runs of
+    characters whose UTF-8 encoding is one byte stay together, every other
+    character becomes a word of its own. The word list is additionally stored
+    in ``data`` under ``split_text_name``. The ``token_type`` argument is
+    ignored; the parent is always built with ``token_type="word"``.
+    """
+
+    def __init__(
+            self,
+            train: bool,
+            token_type: str = None,
+            token_list: Union[Path, str, Iterable[str]] = None,
+            bpemodel: Union[Path, str, Iterable[str]] = None,
+            text_cleaner: Collection[str] = None,
+            g2p_type: str = None,
+            unk_symbol: str = "<unk>",
+            space_symbol: str = "<space>",
+            non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
+            delimiter: str = None,
+            rir_scp: str = None,
+            rir_apply_prob: float = 1.0,
+            noise_scp: str = None,
+            noise_apply_prob: float = 1.0,
+            noise_db_range: str = "3_10",
+            speech_volume_normalize: float = None,
+            speech_name: str = "speech",
+            text_name: str = "text",
+            split_text_name: str = "split_text",
+            split_with_space: bool = False,
+            seg_dict_file: str = None,
+    ):
+        super().__init__(
+            train=train,
+            # Word-level tokenization is enforced regardless of token_type.
+            token_type="word",
+            token_list=token_list,
+            bpemodel=bpemodel,
+            text_cleaner=text_cleaner,
+            g2p_type=g2p_type,
+            unk_symbol=unk_symbol,
+            space_symbol=space_symbol,
+            non_linguistic_symbols=non_linguistic_symbols,
+            delimiter=delimiter,
+            speech_name=speech_name,
+            text_name=text_name,
+            rir_scp=rir_scp,
+            rir_apply_prob=rir_apply_prob,
+            noise_scp=noise_scp,
+            noise_apply_prob=noise_apply_prob,
+            noise_db_range=noise_db_range,
+            speech_volume_normalize=speech_volume_normalize,
+            split_with_space=split_with_space,
+            seg_dict_file=seg_dict_file,
+        )
+        # Key under which the split word list is stored in ``data``.
+        self.split_text_name = split_text_name
+
+    @classmethod
+    def split_words(cls, text: str):
+        """Split ``text`` into single-byte runs and individual multi-byte chars."""
+        result = []
+        for chunk in text.split():
+            # ``chunk`` contains no whitespace.
+            run = []
+            for ch in chunk:
+                if len(ch.encode()) == 1:
+                    # Single-byte (ASCII) character: extend the current run.
+                    run.append(ch)
+                    continue
+                # Multi-byte character: flush the pending run, then emit the
+                # character as a word of its own.
+                if run:
+                    result.append("".join(run))
+                    run = []
+                result.append(ch)
+            if run:
+                result.append("".join(run))
+        return result
+
+    def __call__(
+            self, uid: str, data: Dict[str, Union[list, str, np.ndarray]]
+    ) -> Dict[str, Union[list, np.ndarray]]:
+        """Split the text into words, run the parent pipeline, attach the words."""
+        assert check_argument_types()
+        raw_text = data[self.text_name]
+        # A non-string entry is taken as an already split word sequence.
+        split_text = self.split_words(raw_text) if isinstance(raw_text, str) else raw_text
+        data[self.text_name] = " ".join(split_text)
+        data = self._text_process(self._speech_process(data))
+        data[self.split_text_name] = split_text
+        return data
+
+    def pop_split_text_data(self, data: Dict[str, Union[str, np.ndarray]]):
+        """Remove the split word list from ``data`` and return it."""
+        return data.pop(self.split_text_name)
+
+
+class PuncTrainTokenizerCommonPreprocessor(CommonPreprocessor):
+    """Multi-tokenizer preprocessor that also extracts a trailing VAD marker.
+
+    ``token_type``, ``token_list``, ``bpemodel`` and ``text_name`` are
+    parallel lists; entry ``i`` of each describes how ``text_name[i]`` is
+    tokenized. When the last token of a text contains ``"vad:"``, that token
+    is removed from the text and the integer following the first four
+    characters is stored in ``data[vad_name]`` (``-1`` when nothing follows).
+
+    Args:
+        token_type: Per-field tokenizer types; ``None`` entries disable
+            tokenization for that field.
+        token_list: Per-field token lists; required wherever the matching
+            ``token_type`` is not ``None``.
+        bpemodel: Per-field values forwarded to ``build_tokenizer``.
+        text_name: Keys of the text entries in ``data``.
+        vad_name: Key under which the VAD index array is stored.
+
+    The remaining arguments are forwarded to ``CommonPreprocessor``.
+
+    Raises:
+        ValueError: If a ``token_type`` entry is set while the matching
+            ``token_list`` entry is ``None``.
+    """
+
+    def __init__(
+            self,
+            train: bool,
+            token_type: List[str] = [None],
+            token_list: List[Union[Path, str, Iterable[str]]] = [None],
+            bpemodel: List[Union[Path, str, Iterable[str]]] = [None],
+            text_cleaner: Collection[str] = None,
+            g2p_type: str = None,
+            unk_symbol: str = "<unk>",
+            space_symbol: str = "<space>",
+            non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
+            delimiter: str = None,
+            rir_scp: str = None,
+            rir_apply_prob: float = 1.0,
+            noise_scp: str = None,
+            noise_apply_prob: float = 1.0,
+            noise_db_range: str = "3_10",
+            speech_volume_normalize: float = None,
+            speech_name: str = "speech",
+            text_name: List[str] = ["text"],
+            vad_name: str = "vad_indexes",
+    ):
+        # TODO(jiatong): sync with Kamo and Jing on interface for preprocessor
+        # The parent is initialised from the first entry of each list; the
+        # tokenizer-related attributes are replaced by per-field lists below.
+        super().__init__(
+            train=train,
+            token_type=token_type[0],
+            token_list=token_list[0],
+            bpemodel=bpemodel[0],
+            text_cleaner=text_cleaner,
+            g2p_type=g2p_type,
+            unk_symbol=unk_symbol,
+            space_symbol=space_symbol,
+            non_linguistic_symbols=non_linguistic_symbols,
+            delimiter=delimiter,
+            speech_name=speech_name,
+            text_name=text_name[0],
+            rir_scp=rir_scp,
+            rir_apply_prob=rir_apply_prob,
+            noise_scp=noise_scp,
+            noise_apply_prob=noise_apply_prob,
+            noise_db_range=noise_db_range,
+            speech_volume_normalize=speech_volume_normalize,
+        )
+
+        assert (
+                len(token_type) == len(token_list) == len(bpemodel) == len(text_name)
+        ), "token_type, token_list, bpemodel, or processing text_name mismatched"
+        self.num_tokenizer = len(token_type)
+        self.tokenizer = []
+        self.token_id_converter = []
+
+        for i in range(self.num_tokenizer):
+            if token_type[i] is not None:
+                if token_list[i] is None:
+                    raise ValueError("token_list is required if token_type is not None")
+
+                self.tokenizer.append(
+                    build_tokenizer(
+                        token_type=token_type[i],
+                        bpemodel=bpemodel[i],
+                        delimiter=delimiter,
+                        space_symbol=space_symbol,
+                        non_linguistic_symbols=non_linguistic_symbols,
+                        g2p_type=g2p_type,
+                    )
+                )
+                self.token_id_converter.append(
+                    TokenIDConverter(
+                        token_list=token_list[i],
+                        unk_symbol=unk_symbol,
+                    )
+                )
+            else:
+                # Keep the lists aligned with text_name.
+                self.tokenizer.append(None)
+                self.token_id_converter.append(None)
+
+        self.text_cleaner = TextCleaner(text_cleaner)
+        self.text_name = text_name  # override the text_name from CommonPreprocessor
+        self.vad_name = vad_name
+
+    def _text_process(
+            self, data: Dict[str, Union[str, np.ndarray]]
+    ) -> Dict[str, np.ndarray]:
+        """Tokenize each configured text field and extract the VAD marker.
+
+        Returns:
+            The same ``data`` dict, with text fields replaced by int64 id
+            arrays and ``data[self.vad_name]`` set when a marker was found.
+        """
+        for i in range(self.num_tokenizer):
+            text_name = self.text_name[i]
+            if text_name in data and self.tokenizer[i] is not None:
+                text = data[text_name]
+                text = self.text_cleaner(text)
+                tokens = self.tokenizer[i].text2tokens(text)
+                # Guard against an empty token list before looking at the
+                # last token.
+                if tokens and "vad:" in tokens[-1]:
+                    # Drop the 4-character "vad:" prefix; what remains is the
+                    # index (an empty remainder means "no index").
+                    vad = tokens[-1][4:]
+                    tokens = tokens[:-1]
+                    if len(vad) == 0:
+                        vad = -1
+                    else:
+                        vad = int(vad)
+                    data[self.vad_name] = np.array([vad], dtype=np.int64)
+                text_ints = self.token_id_converter[i].tokens2ids(tokens)
+                data[text_name] = np.array(text_ints, dtype=np.int64)
+        # The inherited __call__ returns this value, so the processed dict
+        # must be handed back.
+        return data
+
+
+def split_to_mini_sentence(words: list, word_limit: int = 20):
+    """Cut ``words`` into consecutive chunks of at most ``word_limit`` items.
+
+    When ``words`` already fits into one chunk, it is returned as is, wrapped
+    in a one-element list (no copy is made).
+    """
+    assert word_limit > 1
+    total = len(words)
+    if total <= word_limit:
+        return [words]
+    # Stepping by word_limit yields every full chunk plus a shorter tail when
+    # the length is not a multiple of word_limit.
+    return [
+        words[start:start + word_limit]
+        for start in range(0, total, word_limit)
+    ]
+
+
+def build_preprocess(args):
+    """Check ``args.task_name``; only ``"asr"`` is accepted.
+
+    In the lines visible here the ``"asr"`` branch is a placeholder
+    (``pass``) -- NOTE(review): confirm whether a preprocessor is meant to be
+    built and returned for this task.
+
+    Raises:
+        ValueError: If ``args.task_name`` is anything other than ``"asr"``.
+    """
+    if args.task_name == "asr":
+        pass
+    else:
+        raise ValueError(f"Not supported task={args.task_name}")