| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209 |
- import argparse
- import logging
- from typing import Callable
- from typing import Collection
- from typing import Dict
- from typing import List
- from typing import Optional
- from typing import Tuple
- import numpy as np
- import torch
- from typeguard import check_argument_types
- from typeguard import check_return_type
- from funasr.datasets.collate_fn import CommonCollateFn
- from funasr.datasets.preprocessor import CommonPreprocessor
- from funasr.train.abs_model import AbsLM
- from funasr.train.abs_model import LanguageModel
- from funasr.models.seq_rnn_lm import SequentialRNNLM
- from funasr.models.transformer_lm import TransformerLM
- from funasr.tasks.abs_task import AbsTask
- from funasr.text.phoneme_tokenizer import g2p_choices
- from funasr.torch_utils.initialize import initialize
- from funasr.train.class_choices import ClassChoices
- from funasr.train.trainer import Trainer
- from funasr.utils.get_default_kwargs import get_default_kwargs
- from funasr.utils.nested_dict_action import NestedDictAction
- from funasr.utils.types import str2bool
- from funasr.utils.types import str_or_none
# Registry of selectable language-model implementations.
# LMTask.build_model() resolves the chosen name via ``args.lm`` and passes
# ``args.lm_conf`` as keyword arguments to the selected class. Every entry is
# declared with type_check=AbsLM (presumably enforced by ClassChoices).
lm_choices = ClassChoices(
    "lm",
    classes={
        "seq_rnn": SequentialRNNLM,
        "transformer": TransformerLM,
    },
    default="seq_rnn",
    type_check=AbsLM,
)
class LMTask(AbsTask):
    """Task definition for training a standalone language model.

    Combines the ``lm`` class choice, the text preprocessor, the collate
    function and the Trainer into the generic ``AbsTask`` machinery.
    """

    # If you need more than one optimizers, change this value
    num_optimizers: int = 1

    # Add variable objects configurations
    class_choices_list = [lm_choices]

    # If you need to modify train() or eval() procedures, change Trainer class here
    trainer = Trainer

    @classmethod
    def add_task_arguments(cls, parser: argparse.ArgumentParser):
        """Register the LM task's command-line options on ``parser``.

        Args:
            parser: Parser to extend; it is modified in place.

        Returns:
            The same ``parser`` object.
        """
        # NOTE(kamo): Use '_' instead of '-' to avoid confusion
        assert check_argument_types()
        task_group = parser.add_argument_group(description="Task related")

        # NOTE(kamo): add_arguments(..., required=True) can't be used
        # to provide --print_config mode. Therefore --token_list is not marked
        # as required here; build_model() rejects a missing token list instead.
        task_group.add_argument(
            "--token_list",
            type=str_or_none,
            default=None,
            help="A text mapping int-id to token",
        )
        task_group.add_argument(
            "--init",
            type=lambda x: str_or_none(x.lower()),
            default=None,
            help="The initialization method",
            choices=[
                "chainer",
                "xavier_uniform",
                "xavier_normal",
                "kaiming_uniform",
                "kaiming_normal",
                None,
            ],
        )
        task_group.add_argument(
            "--model_conf",
            action=NestedDictAction,
            default=get_default_kwargs(LanguageModel),
            help="The keyword arguments for model class.",
        )
        # Adds the --lm selector and its --lm_conf options.
        for class_choices in cls.class_choices_list:
            class_choices.add_arguments(task_group)

        preprocess_group = parser.add_argument_group(
            description="Preprocess related"
        )
        preprocess_group.add_argument(
            "--use_preprocessor",
            type=str2bool,
            default=True,
            help="Apply preprocessing to data or not",
        )
        preprocess_group.add_argument(
            "--token_type",
            type=str,
            default="bpe",
            # "phn" is required for --g2p to have any effect.
            choices=["bpe", "char", "word", "phn"],
            help="The text will be tokenized in the specified level token",
        )
        preprocess_group.add_argument(
            "--bpemodel",
            type=str_or_none,
            default=None,
            help="The model file for sentencepiece",
        )
        preprocess_group.add_argument(
            "--non_linguistic_symbols",
            type=str_or_none,
            help="non_linguistic_symbols file path",
        )
        preprocess_group.add_argument(
            "--cleaner",
            type=str_or_none,
            choices=[None, "tacotron", "jaconv", "vietnamese"],
            default=None,
            help="Apply text cleaning",
        )
        preprocess_group.add_argument(
            "--g2p",
            type=str_or_none,
            choices=g2p_choices,
            default=None,
            help="Specify g2p method if --token_type=phn",
        )

        assert check_return_type(parser)
        return parser

    @classmethod
    def build_collate_fn(
        cls, args: argparse.Namespace, train: bool
    ) -> Callable[
        [Collection[Tuple[str, Dict[str, np.ndarray]]]],
        Tuple[List[str], Dict[str, torch.Tensor]],
    ]:
        """Return the batch collate function.

        This is a ``CommonCollateFn`` with ``int_pad_value=0``. ``args`` and
        ``train`` are not used.
        """
        assert check_argument_types()
        return CommonCollateFn(int_pad_value=0)

    @classmethod
    def build_preprocess_fn(
        cls, args: argparse.Namespace, train: bool
    ) -> Optional[Callable[[str, Dict[str, np.ndarray]], Dict[str, np.ndarray]]]:
        """Build the per-sample text preprocessor.

        Args:
            args: Parsed options. Reads ``use_preprocessor``, ``token_type``,
                ``token_list``, ``bpemodel``, ``cleaner``, ``g2p`` and
                ``non_linguistic_symbols``.
            train: Forwarded to ``CommonPreprocessor``.

        Returns:
            A ``CommonPreprocessor``, or ``None`` when
            ``args.use_preprocessor`` is false.
        """
        assert check_argument_types()
        if args.use_preprocessor:
            retval = CommonPreprocessor(
                train=train,
                token_type=args.token_type,
                token_list=args.token_list,
                bpemodel=args.bpemodel,
                text_cleaner=args.cleaner,
                g2p_type=args.g2p,
                non_linguistic_symbols=args.non_linguistic_symbols,
            )
        else:
            retval = None
        assert check_return_type(retval)
        return retval

    @classmethod
    def required_data_names(
        cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        """Return the data keys every sample must provide: only ``"text"``."""
        retval = ("text",)
        return retval

    @classmethod
    def optional_data_names(
        cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        """Return the optional data keys; the LM task has none."""
        retval = ()
        return retval

    @classmethod
    def build_model(cls, args: argparse.Namespace) -> LanguageModel:
        """Build the ``LanguageModel`` described by ``args``.

        Args:
            args: Parsed options. Reads ``token_list``, ``lm``, ``lm_conf``,
                ``model_conf`` and ``init``. When ``token_list`` is a file
                path, the file is read as UTF-8, one token per line. The path
                in ``args`` is then replaced by the resulting list.

        Returns:
            The constructed model, initialized with ``args.init`` if it is set.

        Raises:
            RuntimeError: If ``args.token_list`` is not a str, list or tuple
                (e.g. ``None``).
        """
        assert check_argument_types()
        if isinstance(args.token_list, str):
            with open(args.token_list, encoding="utf-8") as f:
                token_list = [line.rstrip() for line in f]

            # "args" is saved as it is in a yaml file by BaseTask.main().
            # Overwriting token_list to keep it as "portable".
            args.token_list = token_list.copy()
        elif isinstance(args.token_list, (tuple, list)):
            # list() copies a list and also accepts a tuple (which has no .copy()).
            token_list = list(args.token_list)
        else:
            raise RuntimeError(
                "token_list must be str, list or tuple, "
                f"but got {type(args.token_list).__name__}"
            )
        vocab_size = len(token_list)
        logging.info("Vocabulary size: %d", vocab_size)

        # 1. Build LM model
        lm_class = lm_choices.get_class(args.lm)
        lm = lm_class(vocab_size=vocab_size, **args.lm_conf)

        # 2. Build ESPnetModel
        # Assume the last-id is sos_and_eos
        model = LanguageModel(lm=lm, vocab_size=vocab_size, **args.model_conf)

        # 3. Initialize
        if args.init is not None:
            initialize(model, args.init)

        return model
|