Просмотр исходного кода

add language model infer pipeline

wucong.lyb 3 лет назад
Родитель
Сommit
9e8a52153d

+ 2 - 1
funasr/bin/lm_calc_perplexity.py

@@ -56,7 +56,7 @@ def calc_perplexity(
     set_all_random_seed(seed)
 
     # 2. Build LM
-    model, train_args = LMTask.build_model_from_file(train_config, model_file, device)
+    model, train_args = LMTask.build_model_from_file(config_file=train_config, model_file=model_file, device=device)
     # Wrape model to make model.nll() data-parallel
     wrapped_model = ForwardAdaptor(model, "nll")
     wrapped_model.to(dtype=getattr(torch, dtype)).eval()
@@ -111,6 +111,7 @@ def calc_perplexity(
                     utt_ppl = log_base ** (_nll / ntoken / np.log(log_base))
 
                 # Write PPL of each utts for debugging or analysis
+                writer["utt2nll"][key] = str(-_nll)
                 writer["utt2ppl"][key] = str(utt_ppl)
                 writer["utt2ntokens"][key] = str(ntoken)
 

+ 406 - 0
funasr/bin/lm_inference.py

@@ -0,0 +1,406 @@
+#!/usr/bin/env python3
+import argparse
+import logging
+from pathlib import Path
+import sys
+import os
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Union
+from typing import Dict
+from typing import Any
+from typing import List
+
+import numpy as np
+import torch
+from torch.nn.parallel import data_parallel
+from typeguard import check_argument_types
+
+from funasr.tasks.lm import LMTask
+from funasr.datasets.preprocessor import LMPreprocessor
+from funasr.utils.cli_utils import get_commandline_args
+from funasr.fileio.datadir_writer import DatadirWriter
+from funasr.torch_utils.device_funcs import to_device
+from funasr.torch_utils.forward_adaptor import ForwardAdaptor
+from funasr.torch_utils.set_all_random_seed import set_all_random_seed
+from funasr.utils import config_argparse
+from funasr.utils.types import float_or_none
+from funasr.utils.types import str2bool
+from funasr.utils.types import str2triple_str
+from funasr.utils.types import str_or_none
+
+def inference(
+    output_dir: str,
+    batch_size: int,
+    dtype: str,
+    ngpu: int,
+    seed: int,
+    num_workers: int,
+    log_level: Union[int, str],
+    train_config: Optional[str],
+    model_file: Optional[str],
+    log_base: Optional[float],
+    key_file: Optional[str] = None,
+    allow_variable_data_keys: bool = False,
+    split_with_space: Optional[bool] = False,
+    seg_dict_file: Optional[str] = None,
+    data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
+    raw_inputs: Union[List[Any], bytes, str] = None,
+    **kwargs,
+):
+    inference_pipeline = inference_modelscope(
+        output_dir=output_dir,
+        raw_inputs=raw_inputs,
+        batch_size=batch_size,
+        dtype=dtype,
+        ngpu=ngpu,
+        seed=seed,
+        num_workers=num_workers,
+        log_level=log_level,
+        key_file=key_file,
+        train_config=train_config,
+        model_file=model_file,
+        log_base = log_base,
+        allow_variable_data_keys = allow_variable_data_keys,
+        split_with_space=split_with_space,
+        seg_dict_file=seg_dict_file,
+        **kwargs,
+    )
+    return inference_pipeline(data_path_and_name_and_type, raw_inputs)
+
+
+def inference_modelscope(
+    batch_size: int,
+    dtype: str,
+    ngpu: int,
+    seed: int,
+    num_workers: int,
+    log_level: Union[int, str],
+    key_file: Optional[str],
+    train_config: Optional[str],
+    model_file: Optional[str],
+    log_base: Optional[float] = 10,
+    allow_variable_data_keys: bool = False,
+    split_with_space: Optional[bool] = False,
+    seg_dict_file: Optional[str] = None,
+    output_dir: Optional[str] = None,
+    param_dict: dict = None,
+    **kwargs,
+):
+    assert check_argument_types()
+    logging.basicConfig(
+        level=log_level,
+        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+    )
+
+    if ngpu >= 1 and torch.cuda.is_available():
+        device = "cuda"
+    else:
+        device = "cpu"
+
+    # 1. Set random-seed
+    set_all_random_seed(seed)
+
+    # 2. Build Model
+    model, train_args = LMTask.build_model_from_file(
+        train_config, model_file, device)
+    wrapped_model = ForwardAdaptor(model, "nll")
+    wrapped_model.to(dtype=getattr(torch, dtype)).to(device=device).eval()
+    logging.info(f"Model:\n{model}")
+
+    preprocessor = LMPreprocessor(
+        train=False,
+        token_type=train_args.token_type,
+        token_list=train_args.token_list,
+        bpemodel=train_args.bpemodel,
+        text_cleaner=train_args.cleaner,
+        g2p_type=train_args.g2p,
+        text_name="text",
+        non_linguistic_symbols=train_args.non_linguistic_symbols,
+        split_with_space=split_with_space,
+        seg_dict_file=seg_dict_file
+    )
+
+    def _forward(
+        data_path_and_name_and_type,
+        raw_inputs: Union[List[Any], bytes, str] = None,
+        output_dir_v2: Optional[str] = None,
+        param_dict: dict = None,
+    ):
+        results = []
+        if output_dir_v2 is not None:
+            writer = DatadirWriter(output_dir_v2)
+        else:
+            writer = None
+
+        if raw_inputs != None:
+            line = raw_inputs.strip()
+            key = "lm demo"
+            if line=="":
+                item = {'key': key, 'value': ""}
+                results.append(item)
+                return results
+            batch = {}
+            batch['text'] = line
+            if preprocessor != None:
+                batch = preprocessor(key, batch)
+            
+            #  Force data-precision
+            for name in batch:
+                value = batch[name]
+                if not isinstance(value, np.ndarray):
+                    raise RuntimeError(
+                        f"All values must be converted to np.ndarray object "
+                        f'by preprocessing, but "{name}" is still {type(value)}.'
+                    )
+                # Cast to desired type
+                if value.dtype.kind == "f":
+                    value = value.astype("float32")
+                elif value.dtype.kind == "i":
+                    value = value.astype("long")
+                else:
+                    raise NotImplementedError(f"Not supported dtype: {value.dtype}")
+                batch[name] = value
+            
+            batch["text_lengths"] = torch.from_numpy(
+                np.array([len(batch["text"])], dtype='int32'))
+            batch["text"] = np.expand_dims(batch["text"], axis=0)
+
+            with torch.no_grad():
+                batch = to_device(batch, device)
+                if ngpu <= 1:
+                    nll, lengths = wrapped_model(**batch)
+                else:
+                    nll, lengths = data_parallel(
+                        wrapped_model, (), range(ngpu), module_kwargs=batch
+                    )
+                ## compute ppl
+                ppl_out_batch = ""
+                ids2tokens = preprocessor.token_id_converter.ids2tokens
+                for sent_ids, sent_nll in zip(batch['text'], nll):
+                    pre_word = "<s>"
+                    cur_word = None
+                    sent_lst = ids2tokens(sent_ids) + ['</s>']
+                    ppl_out = " ".join(sent_lst) + "\n"
+                    for word, word_nll in zip(sent_lst, sent_nll):
+                        cur_word = word
+                        word_nll = -word_nll.cpu()
+                        if log_base is None:
+                            word_prob = np.exp(word_nll)
+                        else:
+                            word_prob = log_base ** (word_nll / np.log(log_base))
+                        ppl_out += '    p( {cur} | {pre} ) = {prob} [ {word_nll} ]\n'.format(
+                            cur=cur_word, 
+                            pre=pre_word, 
+                            prob=round(word_prob.item(), 8),
+                            word_nll=round(word_nll.item(), 8)
+                            )
+                        pre_word = cur_word
+                    
+                    sent_nll_mean = sent_nll.mean().cpu().numpy()
+                    sent_nll_sum = sent_nll.sum().cpu().numpy()
+                    if log_base is None:
+                        sent_ppl = np.exp(sent_nll_mean)
+                    else:
+                        sent_ppl = log_base ** (sent_nll_mean / np.log(log_base))
+                    ppl_out += 'logprob= {sent_nll} ppl= {sent_ppl}\n\n'.format(
+                        sent_nll=round(-sent_nll_sum.item(), 4),
+                        sent_ppl=round(sent_ppl.item(), 4)
+                        )
+                    ppl_out_batch += ppl_out
+                    item = {'key': key, 'value': ppl_out}
+                    if writer is not None:
+                        writer["ppl"][key+":\n"] = ppl_out
+                    results.append(item)
+
+            return results
+                
+        # 3. Build data-iterator
+        loader = LMTask.build_streaming_iterator(
+            data_path_and_name_and_type,
+            dtype=dtype,
+            batch_size=batch_size,
+            key_file=key_file,
+            num_workers=num_workers,
+            preprocess_fn=preprocessor,
+            collate_fn=LMTask.build_collate_fn(train_args, False),
+            allow_variable_data_keys=allow_variable_data_keys,
+            inference=True,
+        )
+
+        # 4. Start for-loop
+        total_nll = 0.0
+        total_ntokens = 0
+        ppl_out_all = ""
+        for keys, batch in loader:
+            assert isinstance(batch, dict), type(batch)
+            assert all(isinstance(s, str) for s in keys), keys
+            _bs = len(next(iter(batch.values())))
+            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
+
+            ppl_out_batch = ""
+            with torch.no_grad():
+                batch = to_device(batch, device)
+                if ngpu <= 1:
+                    # NOTE(kamo): data_parallel also should work with ngpu=1,
+                    # but for debuggability it's better to keep this block.
+                    nll, lengths = wrapped_model(**batch)
+                else:
+                    nll, lengths = data_parallel(
+                        wrapped_model, (), range(ngpu), module_kwargs=batch
+                    )
+                ## print ppl
+                ids2tokens = preprocessor.token_id_converter.ids2tokens
+                for key, sent_ids, sent_nll in zip(keys, batch['text'], nll):
+                    pre_word = "<s>"
+                    cur_word = None
+                    sent_lst = ids2tokens(sent_ids) + ['</s>']
+                    ppl_out = " ".join(sent_lst) + "\n"
+                    for word, word_nll in zip(sent_lst, sent_nll):
+                        cur_word = word
+                        word_nll = -word_nll.cpu()
+                        if log_base is None:
+                            word_prob = np.exp(word_nll)
+                        else:
+                            word_prob = log_base ** (word_nll / np.log(log_base))
+                        ppl_out += '    p( {cur} | {pre} ) = {prob} [ {word_nll} ]\n'.format(
+                            cur=cur_word, 
+                            pre=pre_word, 
+                            prob=round(word_prob.item(), 8),
+                            word_nll=round(word_nll.item(), 8)
+                            )
+                        pre_word = cur_word
+                    
+                    sent_nll_mean = sent_nll.mean().cpu().numpy()
+                    sent_nll_sum = sent_nll.sum().cpu().numpy()
+                    if log_base is None:
+                        sent_ppl = np.exp(sent_nll_mean)
+                    else:
+                        sent_ppl = log_base ** (sent_nll_mean / np.log(log_base))
+                    ppl_out += 'logprob= {sent_nll} ppl= {sent_ppl}\n\n'.format(
+                        sent_nll=round(-sent_nll_sum.item(), 4),
+                        sent_ppl=round(sent_ppl.item(), 4)
+                        )
+                    ppl_out_batch += ppl_out
+                    utt2nll = round(-sent_nll_sum.item(), 5)
+                    item = {'key': key, 'value': ppl_out}
+                    if writer is not None:
+                        writer["ppl"][key+":\n"] = ppl_out
+                        writer["utt2nll"][key] = str(utt2nll)
+                    results.append(item)
+
+            ppl_out_all += ppl_out_batch
+            
+            assert _bs == len(nll) == len(lengths), (_bs, len(nll), len(lengths))
+            # nll: (B, L) -> (B,)
+            nll = nll.detach().cpu().numpy().sum(1)
+            # lengths: (B,)
+            lengths = lengths.detach().cpu().numpy()
+            total_nll += nll.sum()
+            total_ntokens += lengths.sum()
+
+        if log_base is None:
+            ppl = np.exp(total_nll / total_ntokens)
+        else:
+            ppl = log_base ** (total_nll / total_ntokens / np.log(log_base))
+
+        avg_ppl = 'logprob= {total_nll} ppl= {total_ppl}\n'.format(
+            total_nll=round(-total_nll.item(), 4),
+            total_ppl=round(ppl.item(), 4)
+            )
+        item = {'key': 'AVG PPL', 'value': avg_ppl}
+        ppl_out_all += avg_ppl
+        if writer is not None:
+            writer["ppl"]["AVG PPL : "] = avg_ppl
+        results.append(item)
+
+        return results
+
+    return _forward
+
+
+def get_parser():
+    parser = config_argparse.ArgumentParser(
+        description="Calc perplexity",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+
+    parser.add_argument(
+        "--log_level",
+        type=lambda x: x.upper(),
+        default="INFO",
+        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
+        help="The verbose level of logging",
+    )
+
+    parser.add_argument("--output_dir", type=str, required=False)
+    parser.add_argument(
+        "--ngpu",
+        type=int,
+        default=0,
+        help="The number of gpus. 0 indicates CPU mode",
+    )
+    parser.add_argument("--seed", type=int, default=0, help="Random seed")
+    parser.add_argument(
+        "--dtype",
+        default="float32",
+        choices=["float16", "float32", "float64"],
+        help="Data type",
+    )
+    parser.add_argument(
+        "--num_workers",
+        type=int,
+        default=1,
+        help="The number of workers used for DataLoader",
+    )
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=1,
+        help="The batch size for inference",
+    )
+    parser.add_argument(
+        "--log_base",
+        type=float_or_none,
+        default=10,
+        help="The base of logarithm for Perplexity. "
+             "If None, napier's constant is used.",
+        required=False
+    )
+
+    group = parser.add_argument_group("Input data related")
+    group.add_argument(
+        "--data_path_and_name_and_type",
+        type=str2triple_str,
+        action="append",
+        required=False
+    )
+    group.add_argument(
+        "--raw_inputs",
+        type=str,
+        required=False
+    )
+    group.add_argument("--key_file", type=str_or_none)
+    group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
+
+    group.add_argument("--split_with_space", type=str2bool, default=False)
+    group.add_argument("--seg_dict_file", type=str_or_none)
+
+    group = parser.add_argument_group("The model configuration related")
+    group.add_argument("--train_config", type=str)
+    group.add_argument("--model_file", type=str)
+
+    return parser
+
+
+def main(cmd=None):
+    print(get_commandline_args(), file=sys.stderr)
+    parser = get_parser()
+    args = parser.parse_args(cmd)
+    kwargs = vars(args)
+    inference(**kwargs)
+
+if __name__ == "__main__":
+    main()
+

+ 130 - 0
funasr/bin/lm_inference_launch.py

@@ -0,0 +1,130 @@
+#!/usr/bin/env python3
+# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
+#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
+
+import argparse
+import logging
+import os
+import sys
+from typing import Union, Dict, Any
+
+from funasr.utils import config_argparse
+from funasr.utils.cli_utils import get_commandline_args
+from funasr.utils.types import str2bool
+from funasr.utils.types import str2triple_str
+from funasr.utils.types import str_or_none
+from funasr.utils.types import float_or_none
+
+
+def get_parser():
+    parser = config_argparse.ArgumentParser(
+        description="Calc perplexity",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+
+    parser.add_argument(
+        "--log_level",
+        type=lambda x: x.upper(),
+        default="INFO",
+        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
+        help="The verbose level of logging",
+    )
+    parser.add_argument("--output_dir", type=str, required=True)
+    parser.add_argument("--gpuid_list", type=str, required=True)
+    parser.add_argument(
+        "--ngpu",
+        type=int,
+        default=0,
+        help="The number of gpus. 0 indicates CPU mode",
+    )
+    parser.add_argument("--seed", type=int, default=0, help="Random seed")
+    parser.add_argument("--njob", type=int, default=1, help="Random seed")
+    parser.add_argument(
+        "--dtype",
+        default="float32",
+        choices=["float16", "float32", "float64"],
+        help="Data type",
+    )
+    parser.add_argument(
+        "--num_workers",
+        type=int,
+        default=1,
+        help="The number of workers used for DataLoader",
+    )
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=1,
+        help="The batch size for inference",
+    )
+    parser.add_argument(
+        "--log_base",
+        type=float_or_none,
+        default=10,
+        help="The base of logarithm for Perplexity. "
+             "If None, napier's constant is used.",
+        required=False
+    )
+
+    group = parser.add_argument_group("Input data related")
+    group.add_argument(
+        "--data_path_and_name_and_type",
+        type=str2triple_str,
+        action="append",
+        required=False
+    )
+    group.add_argument(
+        "--raw_inputs",
+        type=str,
+        required=False
+    )
+    group.add_argument("--key_file", type=str_or_none)
+    group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
+
+    group.add_argument("--split_with_space", type=str2bool, default=False)
+    group.add_argument("--seg_dict_file", type=str_or_none)
+
+    group = parser.add_argument_group("The model configuration related")
+    group.add_argument("--train_config", type=str)
+    group.add_argument("--model_file", type=str)
+    group.add_argument("--mode", type=str, default="lm")
+    return parser
+
+def inference_launch(mode, **kwargs):
+    if mode == "transformer":
+        from funasr.bin.lm_inference import inference_modelscope
+        return inference_modelscope(**kwargs)
+    else:
+        logging.info("Unknown decoding mode: {}".format(mode))
+        return None
+
+
+def main(cmd=None):
+    print(get_commandline_args(), file=sys.stderr)
+    parser = get_parser()
+    args = parser.parse_args(cmd)
+    kwargs = vars(args)
+    kwargs.pop("config", None)
+
+    # set logging messages
+    logging.basicConfig(
+        level=args.log_level,
+        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+    )
+    logging.info("Decoding args: {}".format(kwargs))
+
+    # gpu setting
+    if args.ngpu > 0:
+        jobid = int(args.output_dir.split(".")[-1])
+        gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
+        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+        os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
+
+    kwargs.pop("gpuid_list", None)
+    kwargs.pop("njob", None)
+    results = inference_launch(**kwargs)
+
+
+if __name__ == "__main__":
+    main()
+

+ 41 - 17
funasr/bin/lm_train.py

@@ -1,22 +1,46 @@
 #!/usr/bin/env python3
-from funasr.tasks.lm import LMTask
-
-
-def get_parser():
-    parser = LMTask.get_parser()
-    return parser
 
+import os
 
-def main(cmd=None):
-    """LM training.
-
-    Example:
-
-        % python lm_train.py asr --print_config --optim adadelta
-        % python lm_train.py --config conf/train_asr.yaml
-    """
-    LMTask.main(cmd=cmd)
+from funasr.tasks.lm import LMTask
 
 
-if __name__ == "__main__":
-    main()
+# for LM Training
+def parse_args():
+    parser = LMTask.get_parser()
+    parser.add_argument(
+        "--gpu_id",
+        type=int,
+        default=0,
+        help="local gpu id.",
+    )
+    args = parser.parse_args()
+    return args
+
+
+def main(args=None, cmd=None):
+    # for LM Training
+    LMTask.main(args=args, cmd=cmd)
+
+
+if __name__ == '__main__':
+    args = parse_args()
+
+    # setup local gpu_id
+    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
+
+    # DDP settings
+    if args.ngpu > 1:
+        args.distributed = True
+    else:
+        args.distributed = False
+    assert args.num_worker_count == 1
+
+    # re-compute batch size: when dataset type is small
+    if args.dataset_type == "small" and args.ngpu != 0:
+        if args.batch_size is not None:
+            args.batch_size = args.batch_size * args.ngpu
+        if args.batch_bins is not None:
+            args.batch_bins = args.batch_bins * args.ngpu
+
+    main(args=args)

+ 283 - 0
funasr/bin/tokenize_text.py

@@ -0,0 +1,283 @@
+#!/usr/bin/env python3
+import argparse
+from collections import Counter
+import logging
+from pathlib import Path
+import sys
+from typing import List
+from typing import Optional
+
+from typeguard import check_argument_types
+
+from funasr.utils.cli_utils import get_commandline_args
+from funasr.text.build_tokenizer import build_tokenizer
+from funasr.text.cleaner import TextCleaner
+from funasr.text.phoneme_tokenizer import g2p_choices
+from funasr.utils.types import str2bool
+from funasr.utils.types import str_or_none
+
+
+def field2slice(field: Optional[str]) -> slice:
+    """Convert field string to slice
+
+    Note that field string accepts 1-based integer.
+
+    Examples:
+        >>> field2slice("1-")
+        slice(0, None, None)
+        >>> field2slice("1-3")
+        slice(0, 3, None)
+        >>> field2slice("-3")
+        slice(None, 3, None)
+    """
+    field = field.strip()
+    try:
+        if "-" in field:
+            # e.g. "2-" or "2-5" or "-7"
+            s1, s2 = field.split("-", maxsplit=1)
+            if s1.strip() == "":
+                s1 = None
+            else:
+                s1 = int(s1)
+                if s1 == 0:
+                    raise ValueError("1-based string")
+            if s2.strip() == "":
+                s2 = None
+            else:
+                s2 = int(s2)
+        else:
+            # e.g. "2"
+            s1 = int(field)
+            s2 = s1 + 1
+            if s1 == 0:
+                raise ValueError("must be 1 or more value")
+    except ValueError:
+        raise RuntimeError(f"Format error: e.g. '2-', '2-5', or '-5': {field}")
+
+    if s1 is None:
+        slic = slice(None, s2)
+    else:
+        # -1 because of 1-based integer following "cut" command
+        # e.g "1-3" -> slice(0, 3)
+        slic = slice(s1 - 1, s2)
+    return slic
+
+
+def tokenize(
+    input: str,
+    output: str,
+    field: Optional[str],
+    delimiter: Optional[str],
+    token_type: str,
+    space_symbol: str,
+    non_linguistic_symbols: Optional[str],
+    bpemodel: Optional[str],
+    log_level: str,
+    write_vocabulary: bool,
+    vocabulary_size: int,
+    remove_non_linguistic_symbols: bool,
+    cutoff: int,
+    add_symbol: List[str],
+    cleaner: Optional[str],
+    g2p: Optional[str],
+):
+    assert check_argument_types()
+
+    logging.basicConfig(
+        level=log_level,
+        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+    )
+    if input == "-":
+        fin = sys.stdin
+    else:
+        fin = Path(input).open("r", encoding="utf-8")
+    if output == "-":
+        fout = sys.stdout
+    else:
+        p = Path(output)
+        p.parent.mkdir(parents=True, exist_ok=True)
+        fout = p.open("w", encoding="utf-8")
+
+    cleaner = TextCleaner(cleaner)
+    tokenizer = build_tokenizer(
+        token_type=token_type,
+        bpemodel=bpemodel,
+        delimiter=delimiter,
+        space_symbol=space_symbol,
+        non_linguistic_symbols=non_linguistic_symbols,
+        remove_non_linguistic_symbols=remove_non_linguistic_symbols,
+        g2p_type=g2p,
+    )
+
+    counter = Counter()
+    if field is not None:
+        field = field2slice(field)
+
+    for line in fin:
+        line = line.rstrip()
+        if field is not None:
+            # e.g. field="2-"
+            # uttidA hello world!! -> hello world!!
+            tokens = line.split(delimiter)
+            tokens = tokens[field]
+            if delimiter is None:
+                line = " ".join(tokens)
+            else:
+                line = delimiter.join(tokens)
+
+        line = cleaner(line)
+        tokens = tokenizer.text2tokens(line)
+        if not write_vocabulary:
+            fout.write(" ".join(tokens) + "\n")
+        else:
+            for t in tokens:
+                counter[t] += 1
+
+    if not write_vocabulary:
+        return
+    
+    ## FIXME
+    ## del duplicate add_symbols in counter
+    for symbol_and_id in add_symbol:
+        # e.g symbol="<blank>:0"
+        try:
+            symbol, idx = symbol_and_id.split(":")
+        except ValueError:
+            raise RuntimeError(f"Format error: e.g. '<blank>:0': {symbol_and_id}")
+        symbol = symbol.strip()
+        if symbol in counter:
+            del counter[symbol]
+
+    # ======= write_vocabulary mode from here =======
+    # Sort by the number of occurrences in descending order
+    # and filter lower frequency words than cutoff value
+    words_and_counts = list(
+        filter(lambda x: x[1] > cutoff, sorted(counter.items(), key=lambda x: -x[1]))
+    )
+    # Restrict the vocabulary size
+    if vocabulary_size > 0:
+        if vocabulary_size < len(add_symbol):
+            raise RuntimeError(f"vocabulary_size is too small: {vocabulary_size}")
+        words_and_counts = words_and_counts[: vocabulary_size - len(add_symbol)]
+
+    # Parse the values of --add_symbol
+    for symbol_and_id in add_symbol:
+        # e.g symbol="<blank>:0"
+        try:
+            symbol, idx = symbol_and_id.split(":")
+            idx = int(idx)
+        except ValueError:
+            raise RuntimeError(f"Format error: e.g. '<blank>:0': {symbol_and_id}")
+        symbol = symbol.strip()
+
+        # e.g. idx=0  -> append as the first symbol
+        # e.g. idx=-1 -> append as the last symbol
+        if idx < 0:
+            idx = len(words_and_counts) + 1 + idx
+        words_and_counts.insert(idx, (symbol, None))
+
+    # Write words
+    for w, c in words_and_counts:
+        fout.write(w + "\n")
+
+    # Logging
+    total_count = sum(counter.values())
+    invocab_count = sum(c for w, c in words_and_counts if c is not None)
+    logging.info(f"OOV rate = {(total_count - invocab_count) / total_count * 100} %")
+
+
+def get_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser(
+        description="Tokenize texts",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    parser.add_argument(
+        "--log_level",
+        type=lambda x: x.upper(),
+        default="INFO",
+        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
+        help="The verbose level of logging",
+    )
+
+    parser.add_argument(
+        "--input", "-i", required=True, help="Input text. - indicates sys.stdin"
+    )
+    parser.add_argument(
+        "--output", "-o", required=True, help="Output text. - indicates sys.stdout"
+    )
+    parser.add_argument(
+        "--field",
+        "-f",
+        help="The target columns of the input text as 1-based integer. e.g 2-",
+    )
+    parser.add_argument(
+        "--token_type",
+        "-t",
+        default="char",
+        choices=["char", "bpe", "word", "phn"],
+        help="Token type",
+    )
+    parser.add_argument("--delimiter", "-d", default=None, help="The delimiter")
+    parser.add_argument("--space_symbol", default="<space>", help="The space symbol")
+    parser.add_argument("--bpemodel", default=None, help="The bpemodel file path")
+    parser.add_argument(
+        "--non_linguistic_symbols",
+        type=str_or_none,
+        help="non_linguistic_symbols file path",
+    )
+    parser.add_argument(
+        "--remove_non_linguistic_symbols",
+        type=str2bool,
+        default=False,
+        help="Remove non-language-symbols from tokens",
+    )
+    parser.add_argument(
+        "--cleaner",
+        type=str_or_none,
+        choices=[None, "tacotron", "jaconv", "vietnamese", "korean_cleaner"],
+        default=None,
+        help="Apply text cleaning",
+    )
+    parser.add_argument(
+        "--g2p",
+        type=str_or_none,
+        choices=g2p_choices,
+        default=None,
+        help="Specify g2p method if --token_type=phn",
+    )
+
+    group = parser.add_argument_group("write_vocabulary mode related")
+    group.add_argument(
+        "--write_vocabulary",
+        type=str2bool,
+        default=False,
+        help="Write tokens list instead of tokenized text per line",
+    )
+    group.add_argument("--vocabulary_size", type=int, default=0, help="Vocabulary size")
+    group.add_argument(
+        "--cutoff",
+        default=0,
+        type=int,
+        help="cut-off frequency used for write-vocabulary mode",
+    )
+    group.add_argument(
+        "--add_symbol",
+        type=str,
+        default=[],
+        action="append",
+        help="Append symbol e.g. --add_symbol '<blank>:0' --add_symbol '<unk>:1'",
+    )
+
+    return parser
+
+
+def main(cmd=None):
+    print(get_commandline_args(), file=sys.stderr)
+    parser = get_parser()
+    args = parser.parse_args(cmd)
+    kwargs = vars(args)
+    tokenize(**kwargs)
+
+
+if __name__ == "__main__":
+    main()

+ 73 - 0
funasr/datasets/preprocessor.py

@@ -58,6 +58,15 @@ def seg_tokenize(txt, seg_dict):
             continue
     return out_txt.strip().split()
 
+def seg_tokenize_wo_pattern(txt, seg_dict):
+    out_txt = ""
+    for word in txt:
+        if word in seg_dict:
+            out_txt += seg_dict[word] + " "
+        else:
+            out_txt += "<unk>" + " "
+    return out_txt.strip().split()
+
 
 def framing(
         x,
@@ -372,6 +381,70 @@ class CommonPreprocessor(AbsPreprocessor):
         data = self._text_process(data)
         return data
 
+## FIXME
+class LMPreprocessor(CommonPreprocessor):
+    def __init__(
+            self,
+            train: bool,
+            token_type: str = None,
+            token_list: Union[Path, str, Iterable[str]] = None,
+            bpemodel: Union[Path, str, Iterable[str]] = None,
+            text_cleaner: Collection[str] = None,
+            g2p_type: str = None,
+            unk_symbol: str = "<unk>",
+            space_symbol: str = "<space>",
+            non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
+            delimiter: str = None,
+            rir_scp: str = None,
+            rir_apply_prob: float = 1.0,
+            noise_scp: str = None,
+            noise_apply_prob: float = 1.0,
+            noise_db_range: str = "3_10",
+            speech_volume_normalize: float = None,
+            speech_name: str = "speech",
+            text_name: str = "text",
+            split_with_space: bool = False,
+            seg_dict_file: str = None,
+    ):
+        super().__init__(train,
+                         token_type,
+                         token_list,
+                         bpemodel,
+                         text_cleaner,
+                         g2p_type,
+                         unk_symbol,
+                         space_symbol,
+                         non_linguistic_symbols,
+                         delimiter,
+                         rir_scp,
+                         rir_apply_prob,
+                         noise_scp,
+                         noise_apply_prob,
+                         noise_db_range,
+                         speech_volume_normalize,
+                         speech_name,
+                         text_name,
+                         split_with_space,
+                         seg_dict_file,
+                         )
+
+    def _text_process(
+            self, data: Dict[str, Union[str, np.ndarray]]
+    ) -> Dict[str, np.ndarray]:
+        if self.text_name in data and self.tokenizer is not None:
+            text = data[self.text_name]
+            text = self.text_cleaner(text)
+            if self.split_with_space:
+                tokens = text.strip().split(" ")
+                if self.seg_dict is not None:
+                    tokens = seg_tokenize_wo_pattern(tokens, self.seg_dict)
+            else:
+                tokens = self.tokenizer.text2tokens(text)
+            text_ints = self.token_id_converter.tokens2ids(tokens)
+            data[self.text_name] = np.array(text_ints, dtype=np.int64)
+        assert check_return_type(data)
+        return data
+
 
 class CommonPreprocessor_multi(AbsPreprocessor):
     def __init__(

+ 2 - 2
funasr/lm/espnet_model.py

@@ -46,10 +46,10 @@ class ESPnetLanguageModel(AbsESPnetModel):
 
         # 1. Create a sentence pair like '<sos> w1 w2 w3' and 'w1 w2 w3 <eos>'
         # text: (Batch, Length) -> x, y: (Batch, Length + 1)
-        x = F.pad(text, [1, 0], "constant", self.eos)
+        x = F.pad(text, [1, 0], "constant", self.sos)
         t = F.pad(text, [0, 1], "constant", self.ignore_id)
         for i, l in enumerate(text_lengths):
-            t[i, l] = self.sos
+            t[i, l] = self.eos
         x_lengths = text_lengths + 1
 
         # 2. Forward Language model

+ 47 - 0
funasr/tasks/abs_task.py

@@ -43,6 +43,7 @@ from funasr.iterators.abs_iter_factory import AbsIterFactory
 from funasr.iterators.chunk_iter_factory import ChunkIterFactory
 from funasr.iterators.multiple_iter_factory import MultipleIterFactory
 from funasr.iterators.sequence_iter_factory import SequenceIterFactory
+from funasr.main_funcs.collect_stats import collect_stats
 from funasr.optimizers.sgd import SGD
 from funasr.optimizers.fairseq_adam import FairseqAdam
 from funasr.samplers.build_batch_sampler import BATCH_TYPES
@@ -1272,6 +1273,52 @@ class AbsTask(ABC):
 
         if args.dry_run:
             pass
+        elif args.collect_stats:
+            # Perform on collect_stats mode. This mode has two roles
+            # - Derive the length and dimension of all input data
+            # - Accumulate feats, square values, and the length for whitening
+
+            if args.valid_batch_size is None:
+                args.valid_batch_size = args.batch_size
+
+            if len(args.train_shape_file) != 0:
+                train_key_file = args.train_shape_file[0]
+            else:
+                train_key_file = None
+            if len(args.valid_shape_file) != 0:
+                valid_key_file = args.valid_shape_file[0]
+            else:
+                valid_key_file = None
+
+            collect_stats(
+                model=model,
+                train_iter=cls.build_streaming_iterator(
+                    data_path_and_name_and_type=args.train_data_path_and_name_and_type,
+                    key_file=train_key_file,
+                    batch_size=args.batch_size,
+                    dtype=args.train_dtype,
+                    num_workers=args.num_workers,
+                    allow_variable_data_keys=args.allow_variable_data_keys,
+                    ngpu=args.ngpu,
+                    preprocess_fn=cls.build_preprocess_fn(args, train=False),
+                    collate_fn=cls.build_collate_fn(args, train=False),
+                ),
+                valid_iter=cls.build_streaming_iterator(
+                    data_path_and_name_and_type=args.valid_data_path_and_name_and_type,
+                    key_file=valid_key_file,
+                    batch_size=args.valid_batch_size,
+                    dtype=args.train_dtype,
+                    num_workers=args.num_workers,
+                    allow_variable_data_keys=args.allow_variable_data_keys,
+                    ngpu=args.ngpu,
+                    preprocess_fn=cls.build_preprocess_fn(args, train=False),
+                    collate_fn=cls.build_collate_fn(args, train=False),
+                ),
+                output_dir=output_dir,
+                ngpu=args.ngpu,
+                log_interval=args.log_interval,
+                write_collected_feats=args.write_collected_feats,
+            )
         else:
             logging.info("Training args: {}".format(args))
             # 6. Loads pre-trained model

+ 1 - 1
funasr/tasks/lm.py

@@ -58,7 +58,7 @@ class LMTask(AbsTask):
         # NOTE(kamo): add_arguments(..., required=True) can't be used
         # to provide --print_config mode. Instead of it, do as
         required = parser.get_default("required")
-        required += ["token_list"]
+        # required += ["token_list"]
 
         group.add_argument(
             "--token_list",