- # Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
- # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
- """Abstract task module."""
- import argparse
- import functools
- import logging
- import os
- import sys
- from abc import ABC
- from abc import abstractmethod
- from dataclasses import dataclass
- from distutils.version import LooseVersion
- from io import BytesIO
- from pathlib import Path
- from typing import Any
- from typing import Callable
- from typing import Dict
- from typing import List
- from typing import Optional
- from typing import Sequence
- from typing import Tuple
- from typing import Union
- import humanfriendly
- import numpy as np
- import torch
- import torch.distributed as dist
- import torch.multiprocessing
- import torch.nn
- import torch.optim
- import yaml
- from funasr.models.base_model import FunASRModel
- from torch.utils.data import DataLoader
- from funasr import __version__
- from funasr.datasets.dataset import AbsDataset
- from funasr.datasets.dataset import DATA_TYPES
- from funasr.datasets.dataset import ESPnetDataset
- from funasr.datasets.iterable_dataset import IterableESPnetDataset
- from funasr.iterators.abs_iter_factory import AbsIterFactory
- from funasr.iterators.chunk_iter_factory import ChunkIterFactory
- from funasr.iterators.multiple_iter_factory import MultipleIterFactory
- from funasr.iterators.sequence_iter_factory import SequenceIterFactory
- from funasr.main_funcs.collect_stats import collect_stats
- from funasr.optimizers.fairseq_adam import FairseqAdam
- from funasr.optimizers.sgd import SGD
- from funasr.samplers.build_batch_sampler import BATCH_TYPES
- from funasr.samplers.build_batch_sampler import build_batch_sampler
- from funasr.samplers.unsorted_batch_sampler import UnsortedBatchSampler
- from funasr.schedulers.noam_lr import NoamLR
- from funasr.schedulers.tri_stage_scheduler import TriStageLR
- from funasr.schedulers.warmup_lr import WarmupLR
- from funasr.torch_utils.load_pretrained_model import load_pretrained_model
- from funasr.torch_utils.model_summary import model_summary
- from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
- from funasr.torch_utils.set_all_random_seed import set_all_random_seed
- from funasr.train.class_choices import ClassChoices
- from funasr.train.distributed_utils import DistributedOption
- from funasr.train.trainer import Trainer
- from funasr.utils import config_argparse
- from funasr.utils.build_dataclass import build_dataclass
- from funasr.utils.cli_utils import get_commandline_args
- from funasr.utils.get_default_kwargs import get_default_kwargs
- from funasr.utils.nested_dict_action import NestedDictAction
- from funasr.utils.types import humanfriendly_parse_size_or_none
- from funasr.utils.types import int_or_none
- from funasr.utils.types import str2bool
- from funasr.utils.types import str2triple_str
- from funasr.utils.types import str_or_int
- from funasr.utils.types import str_or_none
- from funasr.utils.wav_utils import calc_shape, generate_data_list, filter_wav_text
- from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump
- from funasr.modules.lora.utils import mark_only_lora_as_trainable
try:
    # wandb is optional; any import failure simply disables wandb logging.
    import wandb
except Exception:
    wandb = None

# NOTE(review): the original had a no-op
#   if LooseVersion(torch.__version__) >= LooseVersion("1.5.0"): pass else: pass
# here; both branches were `pass`, so it has been removed.

# Registry mapping optimizer name -> optimizer class.
# Extended below when optional packages (torch_optimizer, apex) are installed.
optim_classes = dict(
    adam=torch.optim.Adam,
    fairseq_adam=FairseqAdam,
    adamw=torch.optim.AdamW,
    sgd=SGD,
    adadelta=torch.optim.Adadelta,
    adagrad=torch.optim.Adagrad,
    adamax=torch.optim.Adamax,
    asgd=torch.optim.ASGD,
    lbfgs=torch.optim.LBFGS,
    rmsprop=torch.optim.RMSprop,
    rprop=torch.optim.Rprop,
)
if LooseVersion(torch.__version__) >= LooseVersion("1.10.0"):
    # From 1.10.0, RAdam is officially supported
    optim_classes.update(
        radam=torch.optim.RAdam,
    )
try:
    import torch_optimizer

    optim_classes.update(
        accagd=torch_optimizer.AccSGD,
        adabound=torch_optimizer.AdaBound,
        adamod=torch_optimizer.AdaMod,
        diffgrad=torch_optimizer.DiffGrad,
        lamb=torch_optimizer.Lamb,
        novograd=torch_optimizer.NovoGrad,
        pid=torch_optimizer.PID,
        # torch_optimizer<=0.0.1a10 doesn't support
        # qhadam=torch_optimizer.QHAdam,
        qhm=torch_optimizer.QHM,
        sgdw=torch_optimizer.SGDW,
        yogi=torch_optimizer.Yogi,
    )
    if LooseVersion(torch_optimizer.__version__) < LooseVersion("0.2.0"):
        # From 0.2.0, RAdam is dropped
        optim_classes.update(
            radam=torch_optimizer.RAdam,
        )
    del torch_optimizer
except ImportError:
    pass
try:
    import apex

    optim_classes.update(
        fusedadam=apex.optimizers.FusedAdam,
        fusedlamb=apex.optimizers.FusedLAMB,
        fusednovograd=apex.optimizers.FusedNovoGrad,
        fusedsgd=apex.optimizers.FusedSGD,
    )
    del apex
except ImportError:
    pass
try:
    import fairscale
except ImportError:
    # fairscale is only required for --sharded_ddp; see build_optimizers().
    fairscale = None

# Registry mapping lr-scheduler name -> scheduler class.
scheduler_classes = dict(
    ReduceLROnPlateau=torch.optim.lr_scheduler.ReduceLROnPlateau,
    lambdalr=torch.optim.lr_scheduler.LambdaLR,
    steplr=torch.optim.lr_scheduler.StepLR,
    multisteplr=torch.optim.lr_scheduler.MultiStepLR,
    exponentiallr=torch.optim.lr_scheduler.ExponentialLR,
    CosineAnnealingLR=torch.optim.lr_scheduler.CosineAnnealingLR,
    noamlr=NoamLR,
    warmuplr=WarmupLR,
    tri_stage=TriStageLR,
    cycliclr=torch.optim.lr_scheduler.CyclicLR,
    onecyclelr=torch.optim.lr_scheduler.OneCycleLR,
    CosineAnnealingWarmRestarts=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
)
# Lookups are case-insensitive: normalize registry keys to lower case.
optim_classes = {k.lower(): v for k, v in optim_classes.items()}
scheduler_classes = {k.lower(): v for k, v in scheduler_classes.items()}
@dataclass
class IteratorOptions:
    """Options bundle passed to the iterator-factory builders of AbsTask."""

    # Per-sample preprocessing function (result of build_preprocess_fn).
    preprocess_fn: callable
    # Batch-collation function given to DataLoader (result of build_collate_fn).
    collate_fn: callable
    # List of (file_path, batch_key, file_type) triples describing the data.
    data_path_and_name_and_type: list
    # Shape/length files used for batch building.
    shape_files: list
    # Mini-batch size; used for batch_type 'unsorted', 'sorted', or 'folded'.
    batch_size: int
    # Number of batch bins; used for batch_type 'length' or 'numel'.
    batch_bins: int
    # One of the keys of BATCH_TYPES.
    batch_type: str
    # Maximum cache size for the data loader, in bytes.
    max_cache_size: float
    # Maximum number of file descriptors kept open for 'kaldi_ark' data.
    max_cache_fd: int
    # Whether running under torch.distributed.
    distributed: bool
    # If not None, truncate the iterator to this many batches
    # (presumably for debugging/collect-stats — confirm against callers).
    num_batches: Optional[int]
    # If not None, restrict the number of iterations per epoch.
    num_iters_per_epoch: Optional[int]
    # True for the training iterator, False for validation/plot iterators.
    train: bool
class AbsTask(ABC):
    """Abstract base class defining a trainable task (parser, model, data).

    Subclasses implement the abstract class methods; instances are never
    created (see __init__), everything is exposed as class/static methods.
    """

    # Use @staticmethod, or @classmethod,
    # instead of instance method to avoid God classes

    # If you need more than one optimizers, change this value in inheritance
    num_optimizers: int = 1
    trainer = Trainer
    class_choices_list: List[ClassChoices] = []
    # FIX: the original wrote "finetune_args: None", a bare annotation that
    # never creates the attribute (and `None` is not a valid type).  Assign a
    # real default so `cls.finetune_args` is always readable.
    finetune_args: Optional[argparse.Namespace] = None
- def __init__(self):
- raise RuntimeError("This class can't be instantiated.")
    @classmethod
    @abstractmethod
    def add_task_arguments(cls, parser: argparse.ArgumentParser):
        """Register task-specific command-line arguments on *parser*.

        Called at the end of cls.get_parser(); subclasses add their own
        option groups here.
        """
        pass
    @classmethod
    @abstractmethod
    def build_collate_fn(
        cls, args: argparse.Namespace, train: bool
    ) -> Callable[[Sequence[Dict[str, np.ndarray]]], Dict[str, torch.Tensor]]:
        """Return "collate_fn", which is a callable object and given to DataLoader.

        >>> from torch.utils.data import DataLoader
        >>> loader = DataLoader(collate_fn=cls.build_collate_fn(args, train=True), ...)

        In many cases, you can use our common collate_fn.
        """
        raise NotImplementedError
    @classmethod
    @abstractmethod
    def build_preprocess_fn(
        cls, args: argparse.Namespace, train: bool
    ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
        """Return a per-sample preprocessing function, or None to disable it.

        The callable maps (uid, {name: array}) to a processed {name: array}
        dict (the first str argument is presumably the sample/utterance id —
        confirm against the dataset implementation).
        """
        raise NotImplementedError
    @classmethod
    @abstractmethod
    def required_data_names(
        cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        """Define the required data names for this Task.

        This function is used by
        >>> cls.check_task_requirements()

        If your model is defined as follows,
        >>> from funasr.models.base_model import FunASRModel
        >>> class Model(FunASRModel):
        ...     def forward(self, input, output, opt=None):  pass

        then "required_data_names" should be as
        >>> required_data_names = ('input', 'output')
        """
        raise NotImplementedError
    @classmethod
    @abstractmethod
    def optional_data_names(
        cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        """Define the optional data names for this Task.

        This function is used by
        >>> cls.check_task_requirements()

        If your model is defined as follows,
        >>> from funasr.models.base_model import FunASRModel
        >>> class Model(FunASRModel):
        ...     def forward(self, input, output, opt=None):  pass

        then "optional_data_names" should be as
        >>> optional_data_names = ('opt',)
        """
        raise NotImplementedError
    @classmethod
    @abstractmethod
    def build_model(cls, args: argparse.Namespace) -> FunASRModel:
        """Construct and return the task model from the parsed arguments."""
        raise NotImplementedError
- @classmethod
- def get_parser(cls) -> config_argparse.ArgumentParser:
- class ArgumentDefaultsRawTextHelpFormatter(
- argparse.RawTextHelpFormatter,
- argparse.ArgumentDefaultsHelpFormatter,
- ):
- pass
- parser = config_argparse.ArgumentParser(
- description="base parser",
- formatter_class=ArgumentDefaultsRawTextHelpFormatter,
- )
- # NOTE(kamo): Use '_' instead of '-' to avoid confusion.
- # I think '-' looks really confusing if it's written in yaml.
- # NOTE(kamo): add_arguments(..., required=True) can't be used
- # to provide --print_config mode. Instead of it, do as
- # parser.set_defaults(required=["output_dir"])
- group = parser.add_argument_group("Common configuration")
- group.add_argument(
- "--print_config",
- action="store_true",
- help="Print the config file and exit",
- )
- group.add_argument(
- "--log_level",
- type=lambda x: x.upper(),
- default="INFO",
- choices=("ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
- help="The verbose level of logging",
- )
- group.add_argument(
- "--dry_run",
- type=str2bool,
- default=False,
- help="Perform process without training",
- )
- group.add_argument(
- "--iterator_type",
- type=str,
- choices=["sequence", "chunk", "task", "none"],
- default="sequence",
- help="Specify iterator type",
- )
- group.add_argument("--output_dir", type=str_or_none, default=None)
- group.add_argument(
- "--ngpu",
- type=int,
- default=0,
- help="The number of gpus. 0 indicates CPU mode",
- )
- group.add_argument("--seed", type=int, default=0, help="Random seed")
- group.add_argument(
- "--num_workers",
- type=int,
- default=1,
- help="The number of workers used for DataLoader",
- )
- group.add_argument(
- "--num_att_plot",
- type=int,
- default=3,
- help="The number images to plot the outputs from attention. "
- "This option makes sense only when attention-based model. "
- "We can also disable the attention plot by setting it 0",
- )
- group = parser.add_argument_group("distributed training related")
- group.add_argument(
- "--dist_backend",
- default="nccl",
- type=str,
- help="distributed backend",
- )
- group.add_argument(
- "--dist_init_method",
- type=str,
- default="env://",
- help='if init_method="env://", env values of "MASTER_PORT", "MASTER_ADDR", '
- '"WORLD_SIZE", and "RANK" are referred.',
- )
- group.add_argument(
- "--dist_world_size",
- default=None,
- type=int_or_none,
- help="number of nodes for distributed training",
- )
- group.add_argument(
- "--dist_rank",
- type=int_or_none,
- default=None,
- help="node rank for distributed training",
- )
- group.add_argument(
- # Not starting with "dist_" for compatibility to launch.py
- "--local_rank",
- type=int_or_none,
- default=None,
- help="local rank for distributed training. This option is used if "
- "--multiprocessing_distributed=false",
- )
- group.add_argument(
- "--dist_master_addr",
- default=None,
- type=str_or_none,
- help="The master address for distributed training. "
- "This value is used when dist_init_method == 'env://'",
- )
- group.add_argument(
- "--dist_master_port",
- default=None,
- type=int_or_none,
- help="The master port for distributed training"
- "This value is used when dist_init_method == 'env://'",
- )
- group.add_argument(
- "--dist_launcher",
- default=None,
- type=str_or_none,
- choices=["slurm", "mpi", None],
- help="The launcher type for distributed training",
- )
- group.add_argument(
- "--multiprocessing_distributed",
- default=False,
- type=str2bool,
- help="Use multi-processing distributed training to launch "
- "N processes per node, which has N GPUs. This is the "
- "fastest way to use PyTorch for either single node or "
- "multi node data parallel training",
- )
- group.add_argument(
- "--unused_parameters",
- type=str2bool,
- default=False,
- help="Whether to use the find_unused_parameters in "
- "torch.nn.parallel.DistributedDataParallel ",
- )
- group.add_argument(
- "--sharded_ddp",
- default=False,
- type=str2bool,
- help="Enable sharded training provided by fairscale",
- )
- group = parser.add_argument_group("cudnn mode related")
- group.add_argument(
- "--cudnn_enabled",
- type=str2bool,
- default=torch.backends.cudnn.enabled,
- help="Enable CUDNN",
- )
- group.add_argument(
- "--cudnn_benchmark",
- type=str2bool,
- default=torch.backends.cudnn.benchmark,
- help="Enable cudnn-benchmark mode",
- )
- group.add_argument(
- "--cudnn_deterministic",
- type=str2bool,
- default=True,
- help="Enable cudnn-deterministic mode",
- )
- group = parser.add_argument_group("collect stats mode related")
- group.add_argument(
- "--collect_stats",
- type=str2bool,
- default=False,
- help='Perform on "collect stats" mode',
- )
- group.add_argument(
- "--mc",
- type=bool,
- default=False,
- help="MultiChannel input",
- )
- group.add_argument(
- "--write_collected_feats",
- type=str2bool,
- default=False,
- help='Write the output features from the model when "collect stats" mode',
- )
- group = parser.add_argument_group("Trainer related")
- group.add_argument(
- "--max_epoch",
- type=int,
- default=40,
- help="The maximum number epoch to train",
- )
- group.add_argument(
- "--max_update",
- type=int,
- default=sys.maxsize,
- help="The maximum number update step to train",
- )
- parser.add_argument(
- "--batch_interval",
- type=int,
- default=-1,
- help="The batch interval for saving model.",
- )
- group.add_argument(
- "--patience",
- type=int_or_none,
- default=None,
- help="Number of epochs to wait without improvement "
- "before stopping the training",
- )
- group.add_argument(
- "--val_scheduler_criterion",
- type=str,
- nargs=2,
- default=("valid", "loss"),
- help="The criterion used for the value given to the lr scheduler. "
- 'Give a pair referring the phase, "train" or "valid",'
- 'and the criterion name. The mode specifying "min" or "max" can '
- "be changed by --scheduler_conf",
- )
- group.add_argument(
- "--early_stopping_criterion",
- type=str,
- nargs=3,
- default=("valid", "loss", "min"),
- help="The criterion used for judging of early stopping. "
- 'Give a pair referring the phase, "train" or "valid",'
- 'the criterion name and the mode, "min" or "max", e.g. "acc,max".',
- )
- group.add_argument(
- "--best_model_criterion",
- type=str2triple_str,
- nargs="+",
- default=[
- ("train", "loss", "min"),
- ("valid", "loss", "min"),
- ("train", "acc", "max"),
- ("valid", "acc", "max"),
- ],
- help="The criterion used for judging of the best model. "
- 'Give a pair referring the phase, "train" or "valid",'
- 'the criterion name, and the mode, "min" or "max", e.g. "acc,max".',
- )
- group.add_argument(
- "--keep_nbest_models",
- type=int,
- nargs="+",
- default=[10],
- help="Remove previous snapshots excluding the n-best scored epochs",
- )
- group.add_argument(
- "--nbest_averaging_interval",
- type=int,
- default=0,
- help="The epoch interval to apply model averaging and save nbest models",
- )
- group.add_argument(
- "--grad_clip",
- type=float,
- default=5.0,
- help="Gradient norm threshold to clip",
- )
- group.add_argument(
- "--grad_clip_type",
- type=float,
- default=2.0,
- help="The type of the used p-norm for gradient clip. Can be inf",
- )
- group.add_argument(
- "--grad_noise",
- type=str2bool,
- default=False,
- help="The flag to switch to use noise injection to "
- "gradients during training",
- )
- group.add_argument(
- "--accum_grad",
- type=int,
- default=1,
- help="The number of gradient accumulation",
- )
- group.add_argument(
- "--bias_grad_times",
- type=float,
- default=1.0,
- help="To scale the gradient of contextual related params",
- )
- group.add_argument(
- "--no_forward_run",
- type=str2bool,
- default=False,
- help="Just only iterating data loading without "
- "model forwarding and training",
- )
- group.add_argument(
- "--resume",
- type=str2bool,
- default=False,
- help="Enable resuming if checkpoint is existing",
- )
- group.add_argument(
- "--train_dtype",
- default="float32",
- choices=["float16", "float32", "float64"],
- help="Data type for training.",
- )
- group.add_argument(
- "--use_amp",
- type=str2bool,
- default=False,
- help="Enable Automatic Mixed Precision. This feature requires pytorch>=1.6",
- )
- group.add_argument(
- "--log_interval",
- type=int_or_none,
- default=None,
- help="Show the logs every the number iterations in each epochs at the "
- "training phase. If None is given, it is decided according the number "
- "of training samples automatically .",
- )
- group.add_argument(
- "--use_tensorboard",
- type=str2bool,
- default=True,
- help="Enable tensorboard logging",
- )
- group.add_argument(
- "--use_wandb",
- type=str2bool,
- default=False,
- help="Enable wandb logging",
- )
- group.add_argument(
- "--wandb_project",
- type=str,
- default=None,
- help="Specify wandb project",
- )
- group.add_argument(
- "--wandb_id",
- type=str,
- default=None,
- help="Specify wandb id",
- )
- group.add_argument(
- "--wandb_entity",
- type=str,
- default=None,
- help="Specify wandb entity",
- )
- group.add_argument(
- "--wandb_name",
- type=str,
- default=None,
- help="Specify wandb run name",
- )
- group.add_argument(
- "--wandb_model_log_interval",
- type=int,
- default=-1,
- help="Set the model log period",
- )
- group.add_argument(
- "--detect_anomaly",
- type=str2bool,
- default=False,
- help="Set torch.autograd.set_detect_anomaly",
- )
- group = parser.add_argument_group("Pretraining model related")
- group.add_argument("--pretrain_path", help="This option is obsoleted")
- group.add_argument(
- "--init_param",
- type=str,
- action="append",
- default=[],
- help="Specify the file path used for initialization of parameters. "
- "The format is '<file_path>:<src_key>:<dst_key>:<exclude_keys>', "
- "where file_path is the model file path, "
- "src_key specifies the key of model states to be used in the model file, "
- "dst_key specifies the attribute of the model to be initialized, "
- "and exclude_keys excludes keys of model states for the initialization."
- "e.g.\n"
- " # Load all parameters"
- " --init_param some/where/model.pb\n"
- " # Load only decoder parameters"
- " --init_param some/where/model.pb:decoder:decoder\n"
- " # Load only decoder parameters excluding decoder.embed"
- " --init_param some/where/model.pb:decoder:decoder:decoder.embed\n"
- " --init_param some/where/model.pb:decoder:decoder:decoder.embed\n",
- )
- group.add_argument(
- "--ignore_init_mismatch",
- type=str2bool,
- default=False,
- help="Ignore size mismatch when loading pre-trained model",
- )
- group.add_argument(
- "--freeze_param",
- type=str,
- default=[],
- action="append",
- help="Freeze parameters",
- )
- group = parser.add_argument_group("BatchSampler related")
- group.add_argument(
- "--num_iters_per_epoch",
- type=int_or_none,
- default=None,
- help="Restrict the number of iterations for training per epoch",
- )
- group.add_argument(
- "--batch_size",
- type=int,
- default=20,
- help="The mini-batch size used for training. Used if batch_type='unsorted',"
- " 'sorted', or 'folded'.",
- )
- group.add_argument(
- "--valid_batch_size",
- type=int_or_none,
- default=None,
- help="If not given, the value of --batch_size is used",
- )
- group.add_argument(
- "--batch_bins",
- type=int,
- default=1000000,
- help="The number of batch bins. Used if batch_type='length' or 'numel'",
- )
- group.add_argument(
- "--valid_batch_bins",
- type=int_or_none,
- default=None,
- help="If not given, the value of --batch_bins is used",
- )
- group.add_argument("--train_shape_file", type=str, action="append", default=[])
- group.add_argument("--valid_shape_file", type=str, action="append", default=[])
- group = parser.add_argument_group("Sequence iterator related")
- _batch_type_help = ""
- for key, value in BATCH_TYPES.items():
- _batch_type_help += f'"{key}":\n{value}\n'
- group.add_argument(
- "--batch_type",
- type=str,
- default="length",
- choices=list(BATCH_TYPES),
- help=_batch_type_help,
- )
- group.add_argument(
- "--valid_batch_type",
- type=str_or_none,
- default=None,
- choices=list(BATCH_TYPES) + [None],
- help="If not given, the value of --batch_type is used",
- )
- group.add_argument(
- "--speech_length_min",
- type=int,
- default=-1,
- help="speech length min",
- )
- group.add_argument(
- "--speech_length_max",
- type=int,
- default=-1,
- help="speech length max",
- )
- group.add_argument("--fold_length", type=int, action="append", default=[])
- group.add_argument(
- "--sort_in_batch",
- type=str,
- default="descending",
- choices=["descending", "ascending"],
- help="Sort the samples in each mini-batches by the sample "
- 'lengths. To enable this, "shape_file" must have the length information.',
- )
- group.add_argument(
- "--sort_batch",
- type=str,
- default="descending",
- choices=["descending", "ascending"],
- help="Sort mini-batches by the sample lengths",
- )
- group.add_argument(
- "--multiple_iterator",
- type=str2bool,
- default=False,
- help="Use multiple iterator mode",
- )
- group = parser.add_argument_group("Chunk iterator related")
- group.add_argument(
- "--chunk_length",
- type=str_or_int,
- default=500,
- help="Specify chunk length. e.g. '300', '300,400,500', or '300-400'."
- "If multiple numbers separated by command are given, "
- "one of them is selected randomly for each samples. "
- "If two numbers are given with '-', it indicates the range of the choices. "
- "Note that if the sequence length is shorter than the all chunk_lengths, "
- "the sample is discarded. ",
- )
- group.add_argument(
- "--chunk_shift_ratio",
- type=float,
- default=0.5,
- help="Specify the shift width of chunks. If it's less than 1, "
- "allows the overlapping and if bigger than 1, there are some gaps "
- "between each chunk.",
- )
- group.add_argument(
- "--num_cache_chunks",
- type=int,
- default=1024,
- help="Shuffle in the specified number of chunks and generate mini-batches "
- "More larger this value, more randomness can be obtained.",
- )
- group = parser.add_argument_group("Dataset related")
- _data_path_and_name_and_type_help = (
- "Give three words splitted by comma. It's used for the training data. "
- "e.g. '--train_data_path_and_name_and_type some/path/a.scp,foo,sound'. "
- "The first value, some/path/a.scp, indicates the file path, "
- "and the second, foo, is the key name used for the mini-batch data, "
- "and the last, sound, decides the file type. "
- "This option is repeatable, so you can input any number of features "
- "for your task. Supported file types are as follows:\n\n"
- )
- for key, dic in DATA_TYPES.items():
- _data_path_and_name_and_type_help += f'"{key}":\n{dic["help"]}\n\n'
- # for large dataset
- group.add_argument(
- "--dataset_type",
- type=str,
- default="small",
- help="whether to use dataloader for large dataset",
- )
- parser.add_argument(
- "--dataset_conf",
- action=NestedDictAction,
- default=dict(),
- help=f"The keyword arguments for dataset",
- )
- group.add_argument(
- "--train_data_file",
- type=str,
- default=None,
- help="train_list for large dataset",
- )
- group.add_argument(
- "--valid_data_file",
- type=str,
- default=None,
- help="valid_list for large dataset",
- )
- group.add_argument(
- "--train_data_path_and_name_and_type",
- type=str2triple_str,
- action="append",
- default=[],
- help=_data_path_and_name_and_type_help,
- )
- group.add_argument(
- "--valid_data_path_and_name_and_type",
- type=str2triple_str,
- action="append",
- default=[],
- )
- group.add_argument(
- "--allow_variable_data_keys",
- type=str2bool,
- default=False,
- help="Allow the arbitrary keys for mini-batch with ignoring "
- "the task requirements",
- )
- group.add_argument(
- "--max_cache_size",
- type=humanfriendly.parse_size,
- default=0.0,
- help="The maximum cache size for data loader. e.g. 10MB, 20GB.",
- )
- group.add_argument(
- "--max_cache_fd",
- type=int,
- default=32,
- help="The maximum number of file descriptors to be kept "
- "as opened for ark files. "
- "This feature is only valid when data type is 'kaldi_ark'.",
- )
- group.add_argument(
- "--valid_max_cache_size",
- type=humanfriendly_parse_size_or_none,
- default=None,
- help="The maximum cache size for validation data loader. e.g. 10MB, 20GB. "
- "If None, the 5 percent size of --max_cache_size",
- )
- group = parser.add_argument_group("Optimizer related")
- for i in range(1, cls.num_optimizers + 1):
- suf = "" if i == 1 else str(i)
- group.add_argument(
- f"--optim{suf}",
- type=lambda x: x.lower(),
- default="adadelta",
- choices=list(optim_classes),
- help="The optimizer type",
- )
- group.add_argument(
- f"--optim{suf}_conf",
- action=NestedDictAction,
- default=dict(),
- help="The keyword arguments for optimizer",
- )
- group.add_argument(
- f"--scheduler{suf}",
- type=lambda x: str_or_none(x.lower()),
- default=None,
- choices=list(scheduler_classes) + [None],
- help="The lr scheduler type",
- )
- group.add_argument(
- f"--scheduler{suf}_conf",
- action=NestedDictAction,
- default=dict(),
- help="The keyword arguments for lr scheduler",
- )
- # for training on PAI
- group = parser.add_argument_group("PAI training related")
- group.add_argument(
- "--use_pai",
- type=str2bool,
- default=False,
- help="flag to indicate whether training on PAI",
- )
- group.add_argument(
- "--simple_ddp",
- type=str2bool,
- default=False,
- )
- group.add_argument(
- "--num_worker_count",
- type=int,
- default=1,
- help="The number of machines on PAI.",
- )
- group.add_argument(
- "--access_key_id",
- type=str,
- default=None,
- help="The username for oss.",
- )
- group.add_argument(
- "--access_key_secret",
- type=str,
- default=None,
- help="The password for oss.",
- )
- group.add_argument(
- "--endpoint",
- type=str,
- default=None,
- help="The endpoint for oss.",
- )
- group.add_argument(
- "--bucket_name",
- type=str,
- default=None,
- help="The bucket name for oss.",
- )
- group.add_argument(
- "--oss_bucket",
- default=None,
- help="oss bucket.",
- )
- group.add_argument(
- "--enable_lora",
- type=str2bool,
- default=False,
- help="Apply lora for finetuning.",
- )
- group.add_argument(
- "--lora_bias",
- type=str,
- default="none",
- help="lora bias.",
- )
- cls.trainer.add_arguments(parser)
- cls.add_task_arguments(parser)
- return parser
- @classmethod
- def build_optimizers(
- cls,
- args: argparse.Namespace,
- model: torch.nn.Module,
- ) -> List[torch.optim.Optimizer]:
- if cls.num_optimizers != 1:
- raise RuntimeError(
- "build_optimizers() must be overridden if num_optimizers != 1"
- )
- optim_class = optim_classes.get(args.optim)
- if optim_class is None:
- raise ValueError(f"must be one of {list(optim_classes)}: {args.optim}")
- if args.sharded_ddp:
- if fairscale is None:
- raise RuntimeError("Requiring fairscale. Do 'pip install fairscale'")
- optim = fairscale.optim.oss.OSS(
- params=model.parameters(), optim=optim_class, **args.optim_conf
- )
- else:
- optim = optim_class(model.parameters(), **args.optim_conf)
- optimizers = [optim]
- return optimizers
- @classmethod
- def exclude_opts(cls) -> Tuple[str, ...]:
- """The options not to be shown by --print_config"""
- return "required", "print_config", "config", "ngpu"
    @classmethod
    def get_default_config(cls) -> Dict[str, Any]:
        """Return the configuration as dict.

        This method is used by print_config(): it parses the default
        command-line options and expands each optimizer/scheduler/class-choice
        "*_conf" entry with the default keyword arguments of the selected class.
        """

        def get_class_type(name: str, classes: dict):
            # Resolve a class by name, failing loudly on unknown names.
            _cls = classes.get(name)
            if _cls is None:
                raise ValueError(f"must be one of {list(classes)}: {name}")
            return _cls

        # This method is used only for --print_config
        parser = cls.get_parser()
        args, _ = parser.parse_known_args()
        config = vars(args)
        # Excludes the options not to be shown
        for k in AbsTask.exclude_opts():
            config.pop(k)

        for i in range(1, cls.num_optimizers + 1):
            suf = "" if i == 1 else str(i)
            name = config[f"optim{suf}"]
            optim_class = get_class_type(name, optim_classes)
            conf = get_default_kwargs(optim_class)
            # Overwrite the default by the arguments,
            conf.update(config[f"optim{suf}_conf"])
            # and set it again
            config[f"optim{suf}_conf"] = conf

            name = config[f"scheduler{suf}"]
            # The scheduler is optional; None means "no lr scheduler".
            if name is not None:
                scheduler_class = get_class_type(name, scheduler_classes)
                conf = get_default_kwargs(scheduler_class)
                # Overwrite the default by the arguments,
                conf.update(config[f"scheduler{suf}_conf"])
                # and set it again
                config[f"scheduler{suf}_conf"] = conf

        # Same expansion for every registered class-choice group
        # (e.g. frontend, encoder, decoder, ...).
        for class_choices in cls.class_choices_list:
            if getattr(args, class_choices.name) is not None:
                class_obj = class_choices.get_class(getattr(args, class_choices.name))
                conf = get_default_kwargs(class_obj)
                name = class_choices.name
                # Overwrite the default by the arguments,
                conf.update(config[f"{name}_conf"])
                # and set it again
                config[f"{name}_conf"] = conf
        return config
    @classmethod
    def check_required_command_args(cls, args: argparse.Namespace):
        """Validate that every option listed in ``args.required`` was given.

        Prints the parser help plus an argparse-style error message to stderr
        and exits with status 2 when a required option is missing.
        """
        if hasattr(args, "required"):
            # All destinations must use "_" (argparse converts "-" in option
            # strings, so a "-" here indicates a mis-declared argument).
            for k in vars(args):
                if "-" in k:
                    raise RuntimeError(f'Use "_" instead of "-": parser.get_parser("{k}")')

            # Collect the required options that are still unset (None).
            required = ", ".join(
                f"--{a}" for a in args.required if getattr(args, a) is None
            )

            if len(required) != 0:
                parser = cls.get_parser()
                parser.print_help(file=sys.stderr)
                p = Path(sys.argv[0]).name
                print(file=sys.stderr)
                print(
                    f"{p}: error: the following arguments are required: " f"{required}",
                    file=sys.stderr,
                )
                # Exit code 2 mirrors argparse's own error behavior.
                sys.exit(2)
    @classmethod
    def check_task_requirements(
        cls,
        dataset: Union[AbsDataset, IterableESPnetDataset],
        allow_variable_data_keys: bool,
        train: bool,
        inference: bool = False,
    ) -> None:
        """Check if the dataset satisfies the data-name requirements of this Task.

        Raises:
            RuntimeError: if a required data name is missing from the dataset,
                or (when ``allow_variable_data_keys`` is False) if the dataset
                contains a name the task does not declare.
        """
        # Common tail for both error messages below.
        mes = (
            f"If you intend to use an additional input, modify "
            f'"{cls.__name__}.required_data_names()" or '
            f'"{cls.__name__}.optional_data_names()". '
            f"Otherwise you need to set --allow_variable_data_keys true "
        )

        # Every required data name must be present in the dataset.
        for k in cls.required_data_names(train, inference):
            if not dataset.has_name(k):
                raise RuntimeError(
                    f'"{cls.required_data_names(train, inference)}" are required for'
                    f' {cls.__name__}. but "{dataset.names()}" are input.\n{mes}'
                )

        # Unless variable keys are explicitly allowed, reject any dataset
        # name that is neither required nor optional for this task.
        if not allow_variable_data_keys:
            task_keys = cls.required_data_names(
                train, inference
            ) + cls.optional_data_names(train, inference)
            for k in dataset.names():
                if k not in task_keys:
                    raise RuntimeError(
                        f"The data-name must be one of {task_keys} "
                        f'for {cls.__name__}: "{k}" is not allowed.\n{mes}'
                    )
- @classmethod
- def print_config(cls, file=sys.stdout) -> None:
- # Shows the config: e.g. python train.py asr --print_config
- config = cls.get_default_config()
- file.write(yaml_no_alias_safe_dump(config, indent=4, sort_keys=False))
    @classmethod
    def main(cls, args: argparse.Namespace = None, cmd: Sequence[str] = None):
        """Command-line entry point.

        Parses *cmd* when *args* is not supplied, handles --print_config,
        validates required options, then runs main_worker().
        """
        print(get_commandline_args(), file=sys.stderr)
        if args is None:
            parser = cls.get_parser()
            args = parser.parse_args(cmd)
        args.version = __version__
        if args.pretrain_path is not None:
            raise RuntimeError("--pretrain_path is deprecated. Use --init_param")
        if args.print_config:
            cls.print_config()
            sys.exit(0)
        cls.check_required_command_args(args)

        if not args.distributed or not args.multiprocessing_distributed:
            cls.main_worker(args)
        else:
            assert args.ngpu > 1
            # NOTE(review): this branch is identical to the one above except
            # for the ngpu assertion — no multiprocessing spawn happens here.
            # Presumably per-rank processes are launched externally; confirm.
            cls.main_worker(args)
- @classmethod
- def run(cls):
- assert hasattr(cls, "finetune_args")
- args = cls.finetune_args
- args.train_shape_file = None
- if args.distributed:
- args.simple_ddp = True
- else:
- args.simple_ddp = False
- args.ngpu = 1
- args.use_pai = False
- args.batch_type = "length"
- args.oss_bucket = None
- args.input_size = None
- cls.main_worker(args)
    @classmethod
    def main_worker(cls, args: argparse.Namespace):
        """Run training for one worker process.

        Steps: init distributed, prepare data lists/shape files, configure
        logging, seed RNGs, build model/optimizers/schedulers, dump the
        config, then either collect stats or run the trainer.
        """
        # 0. Init distributed process
        distributed_option = build_dataclass(DistributedOption, args)
        # Setting distributed_option.dist_rank, etc.
        if args.use_pai:
            distributed_option.init_options_pai()
        elif not args.simple_ddp:
            distributed_option.init_options()

        # Invoking torch.distributed.init_process_group
        if args.use_pai:
            distributed_option.init_torch_distributed_pai(args)
        elif not args.simple_ddp:
            distributed_option.init_torch_distributed(args)
        elif args.distributed and args.simple_ddp:
            distributed_option.init_torch_distributed_pai(args)
            # With simple DDP the world size determines the GPU count.
            args.ngpu = dist.get_world_size()

        # Scale batch sizes by the number of GPUs for the small dataset type.
        if args.dataset_type == "small" and args.ngpu > 0:
            if args.batch_size is not None:
                args.batch_size = args.batch_size * args.ngpu
            if args.batch_bins is not None and args.ngpu > 0:
                args.batch_bins = args.batch_bins * args.ngpu

        # filter samples if wav.scp and text are mismatch
        # (only rank 0 writes; other ranks wait on the barrier below)
        if (
            args.train_shape_file is None and args.dataset_type == "small"
        ) or args.train_data_file is None and args.dataset_type == "large":
            if not args.simple_ddp or distributed_option.dist_rank == 0:
                filter_wav_text(args.data_dir, args.train_set)
                filter_wav_text(args.data_dir, args.dev_set)
            if args.simple_ddp:
                dist.barrier()

        # Small datasets: compute speech-shape files when absent.
        if args.train_shape_file is None and args.dataset_type == "small":
            if not args.simple_ddp or distributed_option.dist_rank == 0:
                calc_shape(args.data_dir, args.train_set, args.frontend_conf, args.speech_length_min,
                           args.speech_length_max)
                calc_shape(args.data_dir, args.dev_set, args.frontend_conf, args.speech_length_min,
                           args.speech_length_max)
            if args.simple_ddp:
                dist.barrier()
            args.train_shape_file = [os.path.join(args.data_dir, args.train_set, "speech_shape")]
            args.valid_shape_file = [os.path.join(args.data_dir, args.dev_set, "speech_shape")]

        # Large datasets: generate data.list files when absent.
        if args.train_data_file is None and args.dataset_type == "large":
            if not args.simple_ddp or distributed_option.dist_rank == 0:
                generate_data_list(args.data_dir, args.train_set)
                generate_data_list(args.data_dir, args.dev_set)
            if args.simple_ddp:
                dist.barrier()
            args.train_data_file = os.path.join(args.data_dir, args.train_set, "data.list")
            args.valid_data_file = os.path.join(args.data_dir, args.dev_set, "data.list")

        # NOTE(kamo): Don't use logging before invoking logging.basicConfig()
        if not distributed_option.distributed or distributed_option.dist_rank == 0:
            if not distributed_option.distributed:
                _rank = ""
            else:
                _rank = (
                    f":{distributed_option.dist_rank}/"
                    f"{distributed_option.dist_world_size}"
                )

            # NOTE(kamo):
            # logging.basicConfig() is invoked in main_worker() instead of main()
            # because it can be invoked only once in a process.
            # FIXME(kamo): Should we use logging.getLogger()?
            # BUGFIX: Remove previous handlers and reset log level
            for handler in logging.root.handlers[:]:
                logging.root.removeHandler(handler)
            logging.basicConfig(
                level=args.log_level,
                format=f"[{os.uname()[1].split('.')[0]}]"
                f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
            )
        else:
            # BUGFIX: Remove previous handlers and reset log level
            for handler in logging.root.handlers[:]:
                logging.root.removeHandler(handler)
            # Suppress logging if RANK != 0
            logging.basicConfig(
                level="ERROR",
                format=f"[{os.uname()[1].split('.')[0]}]"
                f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
            )
        logging.info("world size: {}, rank: {}, local_rank: {}".format(distributed_option.dist_world_size,
                                                                       distributed_option.dist_rank,
                                                                       distributed_option.local_rank))

        # 1. Set random-seed
        set_all_random_seed(args.seed)
        torch.backends.cudnn.enabled = args.cudnn_enabled
        torch.backends.cudnn.benchmark = args.cudnn_benchmark
        torch.backends.cudnn.deterministic = args.cudnn_deterministic
        if args.detect_anomaly:
            logging.info("Invoking torch.autograd.set_detect_anomaly(True)")
            torch.autograd.set_detect_anomaly(args.detect_anomaly)

        # 2. Build model
        model = cls.build_model(args=args)
        if not isinstance(model, FunASRModel):
            raise RuntimeError(
                f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
            )
        model = model.to(
            dtype=getattr(torch, args.train_dtype),
            device="cuda" if args.ngpu > 0 else "cpu",
        )
        if args.enable_lora:
            # Freeze everything except the LoRA adapter parameters.
            mark_only_lora_as_trainable(model, args.lora_bias)
        # Freeze any parameter whose name matches (or is prefixed by) an
        # entry in --freeze_param.
        for t in args.freeze_param:
            for k, p in model.named_parameters():
                if k.startswith(t + ".") or k == t:
                    logging.info(f"Setting {k}.requires_grad = False")
                    p.requires_grad = False

        # 3. Build optimizer
        optimizers = cls.build_optimizers(args, model=model)

        # 4. Build schedulers
        schedulers = []
        for i, optim in enumerate(optimizers, 1):
            suf = "" if i == 1 else str(i)
            name = getattr(args, f"scheduler{suf}")
            conf = getattr(args, f"scheduler{suf}_conf")
            if name is not None:
                cls_ = scheduler_classes.get(name)
                if cls_ is None:
                    raise ValueError(
                        f"must be one of {list(scheduler_classes)}: {name}"
                    )
                scheduler = cls_(optim, **conf)
            else:
                scheduler = None
            schedulers.append(scheduler)

        logging.info(pytorch_cudnn_version())
        logging.info(model_summary(model))
        for i, (o, s) in enumerate(zip(optimizers, schedulers), 1):
            suf = "" if i == 1 else str(i)
            logging.info(f"Optimizer{suf}:\n{o}")
            logging.info(f"Scheduler{suf}: {s}")

        # 5. Dump "args" to config.yaml
        # NOTE(kamo): "args" should be saved after object-buildings are done
        # because they are allowed to modify "args".
        output_dir = Path(args.output_dir)
        if not distributed_option.distributed or distributed_option.dist_rank == 0:
            output_dir.mkdir(parents=True, exist_ok=True)
            with (output_dir / "config.yaml").open("w", encoding="utf-8") as f:
                logging.info(
                    f'Saving the configuration in {output_dir / "config.yaml"}'
                )
                if args.use_pai:
                    # On PAI the config is serialized and uploaded to OSS
                    # instead of being written to the local yaml file.
                    buffer = BytesIO()
                    torch.save({"config": vars(args)}, buffer)
                    args.oss_bucket.put_object(os.path.join(args.output_dir, "config.dict"), buffer.getvalue())
                else:
                    yaml_no_alias_safe_dump(vars(args), f, indent=4, sort_keys=False)

        if args.dry_run:
            # Stop after building objects and dumping the config.
            pass
        elif args.collect_stats:
            # Perform on collect_stats mode. This mode has two roles
            # - Derive the length and dimension of all input data
            # - Accumulate feats, square values, and the length for whitening
            if args.valid_batch_size is None:
                args.valid_batch_size = args.batch_size

            if len(args.train_shape_file) != 0:
                train_key_file = args.train_shape_file[0]
            else:
                train_key_file = None
            if len(args.valid_shape_file) != 0:
                valid_key_file = args.valid_shape_file[0]
            else:
                valid_key_file = None

            collect_stats(
                model=model,
                train_iter=cls.build_streaming_iterator(
                    data_path_and_name_and_type=args.train_data_path_and_name_and_type,
                    key_file=train_key_file,
                    batch_size=args.batch_size,
                    mc=args.mc,
                    dtype=args.train_dtype,
                    num_workers=args.num_workers,
                    allow_variable_data_keys=args.allow_variable_data_keys,
                    ngpu=args.ngpu,
                    preprocess_fn=cls.build_preprocess_fn(args, train=False),
                    collate_fn=cls.build_collate_fn(args, train=False),
                ),
                valid_iter=cls.build_streaming_iterator(
                    data_path_and_name_and_type=args.valid_data_path_and_name_and_type,
                    key_file=valid_key_file,
                    batch_size=args.valid_batch_size,
                    mc=args.mc,
                    dtype=args.train_dtype,
                    num_workers=args.num_workers,
                    allow_variable_data_keys=args.allow_variable_data_keys,
                    ngpu=args.ngpu,
                    preprocess_fn=cls.build_preprocess_fn(args, train=False),
                    collate_fn=cls.build_collate_fn(args, train=False),
                ),
                output_dir=output_dir,
                ngpu=args.ngpu,
                log_interval=args.log_interval,
                write_collected_feats=args.write_collected_feats,
            )
        else:
            logging.info("Training args: {}".format(args))
            # 6. Loads pre-trained model
            for p in args.init_param:
                logging.info(f"Loading pretrained params from {p}")
                load_pretrained_model(
                    model=model,
                    init_param=p,
                    ignore_init_mismatch=args.ignore_init_mismatch,
                    # NOTE(kamo): "cuda" for torch.load always indicates cuda:0
                    # in PyTorch<=1.4
                    map_location=f"cuda:{torch.cuda.current_device()}"
                    if args.ngpu > 0
                    else "cpu",
                    oss_bucket=args.oss_bucket,
                )

            # 7. Build iterator factories
            if args.dataset_type == "large":
                from funasr.datasets.large_datasets.build_dataloader import LargeDataLoader
                train_iter_factory = LargeDataLoader(args, mode="train")
                valid_iter_factory = LargeDataLoader(args, mode="eval")
            elif args.dataset_type == "small":
                train_iter_factory = cls.build_iter_factory(
                    args=args,
                    distributed_option=distributed_option,
                    mode="train",
                )
                valid_iter_factory = cls.build_iter_factory(
                    args=args,
                    distributed_option=distributed_option,
                    mode="valid",
                )
            else:
                raise ValueError(f"Not supported dataset_type={args.dataset_type}")

            # args.max_update is computed in build_sequence_iter_factory for
            # the "small" dataset type before this point.
            if args.scheduler == "tri_stage":
                for scheduler in schedulers:
                    scheduler.init_tri_stage_scheudler(max_update=args.max_update)

            # 8. Start training
            if args.use_wandb:
                if wandb is None:
                    raise RuntimeError("Please install wandb")
                try:
                    wandb.login()
                except wandb.errors.UsageError:
                    # Degrade gracefully when wandb credentials are missing.
                    logging.info("wandb not configured! run `wandb login` to enable")
                    args.use_wandb = False

            if args.use_wandb:
                if (
                    not distributed_option.distributed
                    or distributed_option.dist_rank == 0
                ):
                    if args.wandb_project is None:
                        project = "FunASR_" + cls.__name__
                    else:
                        project = args.wandb_project
                    if args.wandb_name is None:
                        name = str(Path(".").resolve()).replace("/", "_")
                    else:
                        name = args.wandb_name
                    wandb.init(
                        entity=args.wandb_entity,
                        project=project,
                        name=name,
                        dir=output_dir,
                        id=args.wandb_id,
                        resume="allow",
                    )
                    wandb.config.update(args)
                else:
                    # wandb also supports grouping for distributed training,
                    # but we only logs aggregated data,
                    # so it's enough to perform on rank0 node.
                    args.use_wandb = False

            # Don't give args to trainer.run() directly!!!
            # Instead of it, define "Options" object and build here.
            trainer_options = cls.trainer.build_options(args)
            cls.trainer.run(
                model=model,
                optimizers=optimizers,
                schedulers=schedulers,
                train_iter_factory=train_iter_factory,
                valid_iter_factory=valid_iter_factory,
                trainer_options=trainer_options,
                distributed_option=distributed_option,
            )

            if args.use_wandb and wandb.run:
                wandb.finish()
    @classmethod
    def build_iter_options(
        cls,
        args: argparse.Namespace,
        distributed_option: DistributedOption,
        mode: str,
    ):
        """Assemble an IteratorOptions bundle for "train" or "valid" mode.

        Valid mode falls back to the corresponding train option whenever a
        ``--valid_*`` option was left unset.
        """
        if mode == "train":
            preprocess_fn = cls.build_preprocess_fn(args, train=True)
            collate_fn = cls.build_collate_fn(args, train=True)
            data_path_and_name_and_type = args.train_data_path_and_name_and_type
            shape_files = args.train_shape_file
            batch_size = args.batch_size
            batch_bins = args.batch_bins
            batch_type = args.batch_type
            max_cache_size = args.max_cache_size
            max_cache_fd = args.max_cache_fd
            distributed = distributed_option.distributed
            num_batches = None
            num_iters_per_epoch = args.num_iters_per_epoch
            train = True

        elif mode == "valid":
            preprocess_fn = cls.build_preprocess_fn(args, train=False)
            collate_fn = cls.build_collate_fn(args, train=False)
            data_path_and_name_and_type = args.valid_data_path_and_name_and_type
            shape_files = args.valid_shape_file

            if args.valid_batch_type is None:
                batch_type = args.batch_type
            else:
                batch_type = args.valid_batch_type
            if args.valid_batch_size is None:
                batch_size = args.batch_size
            else:
                batch_size = args.valid_batch_size
            if args.valid_batch_bins is None:
                batch_bins = args.batch_bins
            else:
                batch_bins = args.valid_batch_bins
            if args.valid_max_cache_size is None:
                # Cache 5% of maximum size for validation loader
                max_cache_size = 0.05 * args.max_cache_size
            else:
                max_cache_size = args.valid_max_cache_size
            max_cache_fd = args.max_cache_fd
            distributed = distributed_option.distributed
            num_batches = None
            # Validation always runs over the full epoch.
            num_iters_per_epoch = None
            train = False

        else:
            raise NotImplementedError(f"mode={mode}")

        return IteratorOptions(
            preprocess_fn=preprocess_fn,
            collate_fn=collate_fn,
            data_path_and_name_and_type=data_path_and_name_and_type,
            shape_files=shape_files,
            batch_type=batch_type,
            batch_size=batch_size,
            batch_bins=batch_bins,
            num_batches=num_batches,
            max_cache_size=max_cache_size,
            max_cache_fd=max_cache_fd,
            distributed=distributed,
            num_iters_per_epoch=num_iters_per_epoch,
            train=train,
        )
    @classmethod
    def build_iter_factory(
        cls,
        args: argparse.Namespace,
        distributed_option: DistributedOption,
        mode: str,
        kwargs: dict = None,
    ) -> AbsIterFactory:
        """Build a factory object of mini-batch iterator.

        This object is invoked at every epochs to build the iterator for each epoch
        as following:

        >>> iter_factory = cls.build_iter_factory(...)
        >>> for epoch in range(1, max_epoch):
        ...     for keys, batch in iter_factory.build_iter(epoch):
        ...         model(**batch)

        The mini-batches for each epochs are fully controlled by this class.
        Note that the random seed used for shuffling is decided as "seed + epoch" and
        the generated mini-batches can be reproduced when resuming.

        Note that the definition of "epoch" doesn't always indicate
        to run out of the whole training corpus.
        "--num_iters_per_epoch" option restricts the number of iterations for each epoch
        and the rest of samples for the originally epoch are left for the next epoch.
        e.g. If The number of mini-batches equals to 4, the following two are same:

        - 1 epoch without "--num_iters_per_epoch"
        - 4 epoch with "--num_iters_per_epoch" == 4
        """
        iter_options = cls.build_iter_options(args, distributed_option, mode)

        # Overwrite iter_options if any kwargs is given
        if kwargs is not None:
            for k, v in kwargs.items():
                setattr(iter_options, k, v)

        # Dispatch on --iterator_type.
        if args.iterator_type == "sequence":
            return cls.build_sequence_iter_factory(
                args=args,
                iter_options=iter_options,
                mode=mode,
            )
        elif args.iterator_type == "chunk":
            return cls.build_chunk_iter_factory(
                args=args,
                iter_options=iter_options,
                mode=mode,
            )
        elif args.iterator_type == "task":
            return cls.build_task_iter_factory(
                args=args,
                iter_options=iter_options,
                mode=mode,
            )
        else:
            raise RuntimeError(f"Not supported: iterator_type={args.iterator_type}")
    @classmethod
    def build_sequence_iter_factory(
        cls, args: argparse.Namespace, iter_options: IteratorOptions, mode: str
    ) -> AbsIterFactory:
        """Build a SequenceIterFactory over a map-style ESPnetDataset.

        Also sets ``args.max_update`` as a side effect when the tri_stage
        scheduler is used in train mode.
        """
        # Target sample rate comes from frontend_conf["fs"] when available,
        # otherwise defaults to 16 kHz.
        if hasattr(args, "frontend_conf"):
            if args.frontend_conf is not None and "fs" in args.frontend_conf:
                dest_sample_rate = args.frontend_conf["fs"]
            else:
                dest_sample_rate = 16000
        else:
            dest_sample_rate = 16000

        dataset = ESPnetDataset(
            iter_options.data_path_and_name_and_type,
            float_dtype=args.train_dtype,
            preprocess=iter_options.preprocess_fn,
            max_cache_size=iter_options.max_cache_size,
            max_cache_fd=iter_options.max_cache_fd,
            dest_sample_rate=dest_sample_rate,
        )
        cls.check_task_requirements(
            dataset, args.allow_variable_data_keys, train=iter_options.train
        )

        # Optional utt2category file next to the first data file.
        if Path(
            Path(iter_options.data_path_and_name_and_type[0][0]).parent, "utt2category"
        ).exists():
            utt2category_file = str(
                Path(
                    Path(iter_options.data_path_and_name_and_type[0][0]).parent,
                    "utt2category",
                )
            )
        else:
            utt2category_file = None

        batch_sampler = build_batch_sampler(
            type=iter_options.batch_type,
            shape_files=iter_options.shape_files,
            fold_lengths=args.fold_length,
            batch_size=iter_options.batch_size,
            batch_bins=iter_options.batch_bins,
            sort_in_batch=args.sort_in_batch,
            sort_batch=args.sort_batch,
            drop_last=False,
            # Every rank must get at least one sample per batch.
            min_batch_size=torch.distributed.get_world_size()
            if iter_options.distributed
            else 1,
            utt2category_file=utt2category_file,
        )

        batches = list(batch_sampler)
        if iter_options.num_batches is not None:
            batches = batches[: iter_options.num_batches]

        bs_list = [len(batch) for batch in batches]

        logging.info(f"[{mode}] dataset:\n{dataset}")
        logging.info(f"[{mode}] Batch sampler: {batch_sampler}")
        logging.info(
            f"[{mode}] mini-batch sizes summary: N-batch={len(bs_list)}, "
            f"mean={np.mean(bs_list):.1f}, min={np.min(bs_list)}, max={np.max(bs_list)}"
        )
        # Side effect: total update count for the tri_stage lr scheduler,
        # consumed later in main_worker().
        if args.scheduler == "tri_stage" and mode == "train":
            args.max_update = len(bs_list) * args.max_epoch
            logging.info("Max update: {}".format(args.max_update))

        if iter_options.distributed:
            world_size = torch.distributed.get_world_size()
            rank = torch.distributed.get_rank()
            for batch in batches:
                if len(batch) < world_size:
                    raise RuntimeError(
                        f"The batch-size must be equal or more than world_size: "
                        f"{len(batch)} < {world_size}"
                    )
            # Each rank takes a strided slice of every batch.
            batches = [batch[rank::world_size] for batch in batches]

        return SequenceIterFactory(
            dataset=dataset,
            batches=batches,
            seed=args.seed,
            num_iters_per_epoch=iter_options.num_iters_per_epoch,
            shuffle=iter_options.train,
            num_workers=args.num_workers,
            collate_fn=iter_options.collate_fn,
            pin_memory=args.ngpu > 0,
        )
    @classmethod
    def build_chunk_iter_factory(
        cls,
        args: argparse.Namespace,
        iter_options: IteratorOptions,
        mode: str,
    ) -> AbsIterFactory:
        """Build a ChunkIterFactory that yields fixed-length chunks."""
        dataset = ESPnetDataset(
            iter_options.data_path_and_name_and_type,
            float_dtype=args.train_dtype,
            preprocess=iter_options.preprocess_fn,
            max_cache_size=iter_options.max_cache_size,
            max_cache_fd=iter_options.max_cache_fd,
        )
        cls.check_task_requirements(
            dataset, args.allow_variable_data_keys, train=iter_options.train
        )

        # Keys come from the first shape file, or from the first data file
        # when no shape files were given.
        if len(iter_options.shape_files) == 0:
            key_file = iter_options.data_path_and_name_and_type[0][0]
        else:
            key_file = iter_options.shape_files[0]

        # One sample per "batch" here; real batching happens inside
        # ChunkIterFactory at the chunk level.
        batch_sampler = UnsortedBatchSampler(batch_size=1, key_file=key_file)
        batches = list(batch_sampler)
        if iter_options.num_batches is not None:
            batches = batches[: iter_options.num_batches]
        logging.info(f"[{mode}] dataset:\n{dataset}")

        if iter_options.distributed:
            world_size = torch.distributed.get_world_size()
            rank = torch.distributed.get_rank()
            if len(batches) < world_size:
                raise RuntimeError("Number of samples is smaller than world_size")
            if iter_options.batch_size < world_size:
                raise RuntimeError("batch_size must be equal or more than world_size")

            # Distribute the batch size as evenly as possible across ranks;
            # the first (batch_size % world_size) ranks take one extra.
            if rank < iter_options.batch_size % world_size:
                batch_size = iter_options.batch_size // world_size + 1
            else:
                batch_size = iter_options.batch_size // world_size
            num_cache_chunks = args.num_cache_chunks // world_size
            # NOTE(kamo): Split whole corpus by sample numbers without considering
            # each of the lengths, therefore the number of iteration counts are not
            # always equal to each other and the iterations are limitted
            # by the fewest iterations.
            # i.e. the samples over the counts are discarded.
            batches = batches[rank::world_size]
        else:
            batch_size = iter_options.batch_size
            num_cache_chunks = args.num_cache_chunks

        return ChunkIterFactory(
            dataset=dataset,
            batches=batches,
            seed=args.seed,
            batch_size=batch_size,
            # For chunk iterator,
            # --num_iters_per_epoch doesn't indicate the number of iterations,
            # but indicates the number of samples.
            num_samples_per_epoch=iter_options.num_iters_per_epoch,
            shuffle=iter_options.train,
            num_workers=args.num_workers,
            collate_fn=iter_options.collate_fn,
            pin_memory=args.ngpu > 0,
            chunk_length=args.chunk_length,
            chunk_shift_ratio=args.chunk_shift_ratio,
            num_cache_chunks=num_cache_chunks,
        )
    # NOTE(kamo): Not abstract class
    @classmethod
    def build_task_iter_factory(
        cls,
        args: argparse.Namespace,
        iter_options: IteratorOptions,
        mode: str,
    ) -> AbsIterFactory:
        """Build task specific iterator factory.

        Example:

            >>> class YourTask(AbsTask):
            ...     @classmethod
            ...     def add_task_arguments(cls, parser: argparse.ArgumentParser):
            ...         parser.set_defaults(iterator_type="task")
            ...
            ...     @classmethod
            ...     def build_task_iter_factory(
            ...         cls,
            ...         args: argparse.Namespace,
            ...         iter_options: IteratorOptions,
            ...         mode: str,
            ...     ):
            ...         return FooIterFactory(...)
            ...
            ...     @classmethod
            ...     def build_iter_options(
            ...         args: argparse.Namespace,
            ...         distributed_option: DistributedOption,
            ...         mode: str
            ...     ):
            ...         # if you need to customize options object
        """
        # Intentionally unimplemented: subclasses opting into
        # --iterator_type=task must override this method.
        raise NotImplementedError
- @classmethod
- def build_multiple_iter_factory(
- cls, args: argparse.Namespace, distributed_option: DistributedOption, mode: str
- ):
- iter_options = cls.build_iter_options(args, distributed_option, mode)
- assert len(iter_options.data_path_and_name_and_type) > 0, len(
- iter_options.data_path_and_name_and_type
- )
- # 1. Sanity check
- num_splits = None
- for path in [
- path for path, _, _ in iter_options.data_path_and_name_and_type
- ] + list(iter_options.shape_files):
- if not Path(path).is_dir():
- raise RuntimeError(f"{path} is not a directory")
- p = Path(path) / "num_splits"
- if not p.exists():
- raise FileNotFoundError(f"{p} is not found")
- with p.open() as f:
- _num_splits = int(f.read())
- if num_splits is not None and num_splits != _num_splits:
- raise RuntimeError(
- f"Number of splits are mismathed: "
- f"{iter_options.data_path_and_name_and_type[0][0]} and {path}"
- )
- num_splits = _num_splits
- for i in range(num_splits):
- p = Path(path) / f"split.{i}"
- if not p.exists():
- raise FileNotFoundError(f"{p} is not found")
- # 2. Create functions to build an iter factory for each splits
- data_path_and_name_and_type_list = [
- [
- (str(Path(p) / f"split.{i}"), n, t)
- for p, n, t in iter_options.data_path_and_name_and_type
- ]
- for i in range(num_splits)
- ]
- shape_files_list = [
- [str(Path(s) / f"split.{i}") for s in iter_options.shape_files]
- for i in range(num_splits)
- ]
- num_iters_per_epoch_list = [
- (iter_options.num_iters_per_epoch + i) // num_splits
- if iter_options.num_iters_per_epoch is not None
- else None
- for i in range(num_splits)
- ]
- max_cache_size = iter_options.max_cache_size / num_splits
- # Note that iter-factories are built for each epoch at runtime lazily.
- build_funcs = [
- functools.partial(
- cls.build_iter_factory,
- args,
- distributed_option,
- mode,
- kwargs=dict(
- data_path_and_name_and_type=_data_path_and_name_and_type,
- shape_files=_shape_files,
- num_iters_per_epoch=_num_iters_per_epoch,
- max_cache_size=max_cache_size,
- ),
- )
- for (
- _data_path_and_name_and_type,
- _shape_files,
- _num_iters_per_epoch,
- ) in zip(
- data_path_and_name_and_type_list,
- shape_files_list,
- num_iters_per_epoch_list,
- )
- ]
- # 3. Build MultipleIterFactory
- return MultipleIterFactory(
- build_funcs=build_funcs, shuffle=iter_options.train, seed=args.seed
- )
    @classmethod
    def build_streaming_iterator(
        cls,
        data_path_and_name_and_type,
        preprocess_fn,
        collate_fn,
        key_file: str = None,
        batch_size: int = 1,
        fs: dict = None,
        mc: bool = False,
        # NOTE(review): annotated as str but defaults to np.float32 — the
        # dataset's float_dtype presumably accepts both; confirm.
        dtype: str = np.float32,
        num_workers: int = 1,
        allow_variable_data_keys: bool = False,
        ngpu: int = 0,
        inference: bool = False,
    ) -> DataLoader:
        """Build DataLoader using iterable dataset"""
        # For backward compatibility for pytorch DataLoader
        if collate_fn is not None:
            kwargs = dict(collate_fn=collate_fn)
        else:
            kwargs = {}

        dataset = IterableESPnetDataset(
            data_path_and_name_and_type,
            float_dtype=dtype,
            fs=fs,
            mc=mc,
            preprocess=preprocess_fn,
            key_file=key_file,
        )
        # When the dataset groups utterances by category, batching is
        # delegated to the dataset itself, so the loader batch size is 1.
        if dataset.apply_utt2category:
            kwargs.update(batch_size=1)
        else:
            kwargs.update(batch_size=batch_size)

        # Streaming iteration is always validated as non-training data.
        cls.check_task_requirements(
            dataset, allow_variable_data_keys, train=False, inference=inference
        )

        return DataLoader(
            dataset=dataset,
            pin_memory=ngpu > 0,
            num_workers=num_workers,
            **kwargs,
        )
    # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
    @classmethod
    def build_model_from_file(
        cls,
        config_file: Union[Path, str] = None,
        model_file: Union[Path, str] = None,
        cmvn_file: Union[Path, str] = None,
        device: str = "cpu",
    ) -> Tuple[FunASRModel, argparse.Namespace]:
        """Build model from the files.

        This method is used for inference or fine-tuning.

        Args:
            config_file: The yaml file saved when training. When omitted,
                "config.yaml" next to *model_file* is used.
            model_file: The model file saved when training. Optional; when
                omitted no state dict is loaded.
            cmvn_file: Overrides the "cmvn_file" entry of the loaded config
                when given.
            device: Device type, "cpu", "cuda", or "cuda:N".
        """
        if config_file is None:
            assert model_file is not None, (
                "The argument 'model_file' must be provided "
                "if the argument 'config_file' is not specified."
            )
            config_file = Path(model_file).parent / "config.yaml"
        else:
            config_file = Path(config_file)

        with config_file.open("r", encoding="utf-8") as f:
            args = yaml.safe_load(f)
        if cmvn_file is not None:
            args["cmvn_file"] = cmvn_file
        args = argparse.Namespace(**args)
        model = cls.build_model(args)
        if not isinstance(model, FunASRModel):
            raise RuntimeError(
                f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
            )
        model.to(device)
        if model_file is not None:
            if device == "cuda":
                # NOTE(kamo): "cuda" for torch.load always indicates cuda:0
                # in PyTorch<=1.4
                device = f"cuda:{torch.cuda.current_device()}"
            model.load_state_dict(torch.load(model_file, map_location=device))
            # Move again in case device was rewritten to an explicit cuda:N.
            model.to(device)

        return model, args
|