| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
777177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873 |
- # Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
- # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
- """Abstract task module."""
- import argparse
- import functools
- import logging
- import os
- import sys
- from abc import ABC
- from abc import abstractmethod
- from dataclasses import dataclass
- from distutils.version import LooseVersion
- from io import BytesIO
- from pathlib import Path
- from typing import Any
- from typing import Callable
- from typing import Dict
- from typing import List
- from typing import Optional
- from typing import Sequence
- from typing import Tuple
- from typing import Union
- import humanfriendly
- import numpy as np
- import torch
- import torch.distributed as dist
- import torch.multiprocessing
- import torch.nn
- import torch.optim
- import yaml
- from torch.utils.data import DataLoader
- from typeguard import check_argument_types
- from typeguard import check_return_type
- from funasr import __version__
- from funasr.datasets.dataset import AbsDataset
- from funasr.datasets.dataset import DATA_TYPES
- from funasr.datasets.dataset import ESPnetDataset
- from funasr.datasets.iterable_dataset import IterableESPnetDataset
- from funasr.iterators.abs_iter_factory import AbsIterFactory
- from funasr.iterators.chunk_iter_factory import ChunkIterFactory
- from funasr.iterators.multiple_iter_factory import MultipleIterFactory
- from funasr.iterators.sequence_iter_factory import SequenceIterFactory
- from funasr.optimizers.sgd import SGD
- from funasr.optimizers.fairseq_adam import FairseqAdam
- from funasr.samplers.build_batch_sampler import BATCH_TYPES
- from funasr.samplers.build_batch_sampler import build_batch_sampler
- from funasr.samplers.unsorted_batch_sampler import UnsortedBatchSampler
- from funasr.schedulers.noam_lr import NoamLR
- from funasr.schedulers.warmup_lr import WarmupLR
- from funasr.schedulers.tri_stage_scheduler import TriStageLR
- from funasr.torch_utils.load_pretrained_model import load_pretrained_model
- from funasr.torch_utils.model_summary import model_summary
- from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
- from funasr.torch_utils.set_all_random_seed import set_all_random_seed
- from funasr.train.abs_espnet_model import AbsESPnetModel
- from funasr.train.class_choices import ClassChoices
- from funasr.train.distributed_utils import DistributedOption
- from funasr.train.trainer import Trainer
- from funasr.utils import config_argparse
- from funasr.utils.build_dataclass import build_dataclass
- from funasr.utils.cli_utils import get_commandline_args
- from funasr.utils.get_default_kwargs import get_default_kwargs
- from funasr.utils.nested_dict_action import NestedDictAction
- from funasr.utils.types import humanfriendly_parse_size_or_none
- from funasr.utils.types import int_or_none
- from funasr.utils.types import str2bool
- from funasr.utils.types import str2triple_str
- from funasr.utils.types import str_or_int
- from funasr.utils.types import str_or_none
- from funasr.utils.wav_utils import calc_shape, generate_data_list
- from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump
# wandb is an optional dependency: bind the name to None when it is missing so
# later code can feature-test with a plain "wandb is None" check.
try:
    import wandb
except Exception:
    wandb = None
# NOTE: a torch-version branch used to live here
# (``if LooseVersion(torch.__version__) >= LooseVersion("1.5.0"): pass else: pass``).
# Both arms were ``pass``, making it a no-op, so it has been removed.
# Registry of optimizer constructors selectable via --optim.  Keys are
# normalized to lower case further below.
optim_classes = {
    "adam": torch.optim.Adam,
    "fairseq_adam": FairseqAdam,
    "adamw": torch.optim.AdamW,
    "sgd": SGD,
    "adadelta": torch.optim.Adadelta,
    "adagrad": torch.optim.Adagrad,
    "adamax": torch.optim.Adamax,
    "asgd": torch.optim.ASGD,
    "lbfgs": torch.optim.LBFGS,
    "rmsprop": torch.optim.RMSprop,
    "rprop": torch.optim.Rprop,
}
# Register optimizers that are only available on newer torch versions or via
# optional third-party packages.  The registration order matters for "radam":
# the torch builtin (>=1.10.0) may later be overwritten by torch_optimizer's
# implementation when torch_optimizer < 0.2.0 is installed.
if LooseVersion(torch.__version__) >= LooseVersion("1.10.0"):
    # From 1.10.0, RAdam is officially supported
    optim_classes.update(
        radam=torch.optim.RAdam,
    )
try:
    import torch_optimizer

    optim_classes.update(
        # NOTE(review): key "accagd" looks like a typo for "accsgd"
        # (it maps to AccSGD) — kept as-is for backward compatibility.
        accagd=torch_optimizer.AccSGD,
        adabound=torch_optimizer.AdaBound,
        adamod=torch_optimizer.AdaMod,
        diffgrad=torch_optimizer.DiffGrad,
        lamb=torch_optimizer.Lamb,
        novograd=torch_optimizer.NovoGrad,
        pid=torch_optimizer.PID,
        # torch_optimizer<=0.0.1a10 doesn't support
        # qhadam=torch_optimizer.QHAdam,
        qhm=torch_optimizer.QHM,
        sgdw=torch_optimizer.SGDW,
        yogi=torch_optimizer.Yogi,
    )
    if LooseVersion(torch_optimizer.__version__) < LooseVersion("0.2.0"):
        # From 0.2.0, RAdam is dropped
        optim_classes.update(
            radam=torch_optimizer.RAdam,
        )
    del torch_optimizer
except ImportError:
    pass
try:
    import apex

    # NVIDIA Apex fused optimizers (optional, CUDA builds only).
    optim_classes.update(
        fusedadam=apex.optimizers.FusedAdam,
        fusedlamb=apex.optimizers.FusedLAMB,
        fusednovograd=apex.optimizers.FusedNovoGrad,
        fusedsgd=apex.optimizers.FusedSGD,
    )
    del apex
except ImportError:
    pass
# fairscale provides sharded DDP / OSS optimizer support; it is optional, so
# bind to None on ImportError and check at the use sites.
try:
    import fairscale
except ImportError:
    fairscale = None
# Registry of LR-scheduler constructors selectable via --scheduler.  Keys are
# normalized to lower case further below.
scheduler_classes = {
    "ReduceLROnPlateau": torch.optim.lr_scheduler.ReduceLROnPlateau,
    "lambdalr": torch.optim.lr_scheduler.LambdaLR,
    "steplr": torch.optim.lr_scheduler.StepLR,
    "multisteplr": torch.optim.lr_scheduler.MultiStepLR,
    "exponentiallr": torch.optim.lr_scheduler.ExponentialLR,
    "CosineAnnealingLR": torch.optim.lr_scheduler.CosineAnnealingLR,
    "noamlr": NoamLR,
    "warmuplr": WarmupLR,
    "tri_stage": TriStageLR,
    "cycliclr": torch.optim.lr_scheduler.CyclicLR,
    "onecyclelr": torch.optim.lr_scheduler.OneCycleLR,
    "CosineAnnealingWarmRestarts": torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
}
# Normalize registry keys to lower case so CLI lookups (--optim/--scheduler,
# which also lower-case their input) are case-insensitive.
optim_classes = dict((name.lower(), impl) for name, impl in optim_classes.items())
scheduler_classes = dict(
    (name.lower(), impl) for name, impl in scheduler_classes.items()
)
@dataclass
class IteratorOptions:
    """Bundle of settings used to build an iterator factory for one phase."""

    preprocess_fn: Callable  # per-sample preprocessing; may be None in practice
    collate_fn: Callable  # mini-batch collation passed to DataLoader
    data_path_and_name_and_type: list  # (path, name, type) triples for the data
    shape_files: list  # shape/length files used by the batch sampler
    batch_size: int
    batch_bins: int
    batch_type: str
    max_cache_size: float  # data-loader cache budget in bytes
    max_cache_fd: int  # max open file descriptors kept for ark files
    distributed: bool
    num_batches: Optional[int]  # limit on batches; None means no limit
    num_iters_per_epoch: Optional[int]
    train: bool  # True for the training phase, False for validation/plot
- class AbsTask(ABC):
- # Use @staticmethod, or @classmethod,
- # instead of instance method to avoid God classes
- # If you need more than one optimizers, change this value in inheritance
- num_optimizers: int = 1
- trainer = Trainer
- class_choices_list: List[ClassChoices] = []
- finetune_args: None
- def __init__(self):
- raise RuntimeError("This class can't be instantiated.")
- @classmethod
- @abstractmethod
- def add_task_arguments(cls, parser: argparse.ArgumentParser):
- pass
- @classmethod
- @abstractmethod
- def build_collate_fn(
- cls, args: argparse.Namespace, train: bool
- ) -> Callable[[Sequence[Dict[str, np.ndarray]]], Dict[str, torch.Tensor]]:
- """Return "collate_fn", which is a callable object and given to DataLoader.
- >>> from torch.utils.data import DataLoader
- >>> loader = DataLoader(collate_fn=cls.build_collate_fn(args, train=True), ...)
- In many cases, you can use our common collate_fn.
- """
- raise NotImplementedError
- @classmethod
- @abstractmethod
- def build_preprocess_fn(
- cls, args: argparse.Namespace, train: bool
- ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
- raise NotImplementedError
- @classmethod
- @abstractmethod
- def required_data_names(
- cls, train: bool = True, inference: bool = False
- ) -> Tuple[str, ...]:
- """Define the required names by Task
- This function is used by
- >>> cls.check_task_requirements()
- If your model is defined as following,
- >>> from funasr.train.abs_espnet_model import AbsESPnetModel
- >>> class Model(AbsESPnetModel):
- ... def forward(self, input, output, opt=None): pass
- then "required_data_names" should be as
- >>> required_data_names = ('input', 'output')
- """
- raise NotImplementedError
- @classmethod
- @abstractmethod
- def optional_data_names(
- cls, train: bool = True, inference: bool = False
- ) -> Tuple[str, ...]:
- """Define the optional names by Task
- This function is used by
- >>> cls.check_task_requirements()
- If your model is defined as follows,
- >>> from funasr.train.abs_espnet_model import AbsESPnetModel
- >>> class Model(AbsESPnetModel):
- ... def forward(self, input, output, opt=None): pass
- then "optional_data_names" should be as
- >>> optional_data_names = ('opt',)
- """
- raise NotImplementedError
- @classmethod
- @abstractmethod
- def build_model(cls, args: argparse.Namespace) -> AbsESPnetModel:
- raise NotImplementedError
- @classmethod
- def get_parser(cls) -> config_argparse.ArgumentParser:
- assert check_argument_types()
- class ArgumentDefaultsRawTextHelpFormatter(
- argparse.RawTextHelpFormatter,
- argparse.ArgumentDefaultsHelpFormatter,
- ):
- pass
- parser = config_argparse.ArgumentParser(
- description="base parser",
- formatter_class=ArgumentDefaultsRawTextHelpFormatter,
- )
- # NOTE(kamo): Use '_' instead of '-' to avoid confusion.
- # I think '-' looks really confusing if it's written in yaml.
- # NOTE(kamo): add_arguments(..., required=True) can't be used
- # to provide --print_config mode. Instead of it, do as
- # parser.set_defaults(required=["output_dir"])
- group = parser.add_argument_group("Common configuration")
- group.add_argument(
- "--print_config",
- action="store_true",
- help="Print the config file and exit",
- )
- group.add_argument(
- "--log_level",
- type=lambda x: x.upper(),
- default="INFO",
- choices=("ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
- help="The verbose level of logging",
- )
- group.add_argument(
- "--dry_run",
- type=str2bool,
- default=False,
- help="Perform process without training",
- )
- group.add_argument(
- "--iterator_type",
- type=str,
- choices=["sequence", "chunk", "task", "none"],
- default="sequence",
- help="Specify iterator type",
- )
- group.add_argument("--output_dir", type=str_or_none, default=None)
- group.add_argument(
- "--ngpu",
- type=int,
- default=0,
- help="The number of gpus. 0 indicates CPU mode",
- )
- group.add_argument("--seed", type=int, default=0, help="Random seed")
- group.add_argument(
- "--num_workers",
- type=int,
- default=1,
- help="The number of workers used for DataLoader",
- )
- group.add_argument(
- "--num_att_plot",
- type=int,
- default=3,
- help="The number images to plot the outputs from attention. "
- "This option makes sense only when attention-based model. "
- "We can also disable the attention plot by setting it 0",
- )
- group = parser.add_argument_group("distributed training related")
- group.add_argument(
- "--dist_backend",
- default="nccl",
- type=str,
- help="distributed backend",
- )
- group.add_argument(
- "--dist_init_method",
- type=str,
- default="env://",
- help='if init_method="env://", env values of "MASTER_PORT", "MASTER_ADDR", '
- '"WORLD_SIZE", and "RANK" are referred.',
- )
- group.add_argument(
- "--dist_world_size",
- default=None,
- type=int_or_none,
- help="number of nodes for distributed training",
- )
- group.add_argument(
- "--dist_rank",
- type=int_or_none,
- default=None,
- help="node rank for distributed training",
- )
- group.add_argument(
- # Not starting with "dist_" for compatibility to launch.py
- "--local_rank",
- type=int_or_none,
- default=None,
- help="local rank for distributed training. This option is used if "
- "--multiprocessing_distributed=false",
- )
- group.add_argument(
- "--dist_master_addr",
- default=None,
- type=str_or_none,
- help="The master address for distributed training. "
- "This value is used when dist_init_method == 'env://'",
- )
- group.add_argument(
- "--dist_master_port",
- default=None,
- type=int_or_none,
- help="The master port for distributed training"
- "This value is used when dist_init_method == 'env://'",
- )
- group.add_argument(
- "--dist_launcher",
- default=None,
- type=str_or_none,
- choices=["slurm", "mpi", None],
- help="The launcher type for distributed training",
- )
- group.add_argument(
- "--multiprocessing_distributed",
- default=False,
- type=str2bool,
- help="Use multi-processing distributed training to launch "
- "N processes per node, which has N GPUs. This is the "
- "fastest way to use PyTorch for either single node or "
- "multi node data parallel training",
- )
- group.add_argument(
- "--unused_parameters",
- type=str2bool,
- default=False,
- help="Whether to use the find_unused_parameters in "
- "torch.nn.parallel.DistributedDataParallel ",
- )
- group.add_argument(
- "--sharded_ddp",
- default=False,
- type=str2bool,
- help="Enable sharded training provided by fairscale",
- )
- group = parser.add_argument_group("cudnn mode related")
- group.add_argument(
- "--cudnn_enabled",
- type=str2bool,
- default=torch.backends.cudnn.enabled,
- help="Enable CUDNN",
- )
- group.add_argument(
- "--cudnn_benchmark",
- type=str2bool,
- default=torch.backends.cudnn.benchmark,
- help="Enable cudnn-benchmark mode",
- )
- group.add_argument(
- "--cudnn_deterministic",
- type=str2bool,
- default=True,
- help="Enable cudnn-deterministic mode",
- )
- group = parser.add_argument_group("collect stats mode related")
- group.add_argument(
- "--collect_stats",
- type=str2bool,
- default=False,
- help='Perform on "collect stats" mode',
- )
- group.add_argument(
- "--write_collected_feats",
- type=str2bool,
- default=False,
- help='Write the output features from the model when "collect stats" mode',
- )
- group = parser.add_argument_group("Trainer related")
- group.add_argument(
- "--max_epoch",
- type=int,
- default=40,
- help="The maximum number epoch to train",
- )
- group.add_argument(
- "--max_update",
- type=int,
- default=sys.maxsize,
- help="The maximum number update step to train",
- )
- group.add_argument(
- "--patience",
- type=int_or_none,
- default=None,
- help="Number of epochs to wait without improvement "
- "before stopping the training",
- )
- group.add_argument(
- "--val_scheduler_criterion",
- type=str,
- nargs=2,
- default=("valid", "loss"),
- help="The criterion used for the value given to the lr scheduler. "
- 'Give a pair referring the phase, "train" or "valid",'
- 'and the criterion name. The mode specifying "min" or "max" can '
- "be changed by --scheduler_conf",
- )
- group.add_argument(
- "--early_stopping_criterion",
- type=str,
- nargs=3,
- default=("valid", "loss", "min"),
- help="The criterion used for judging of early stopping. "
- 'Give a pair referring the phase, "train" or "valid",'
- 'the criterion name and the mode, "min" or "max", e.g. "acc,max".',
- )
- group.add_argument(
- "--best_model_criterion",
- type=str2triple_str,
- nargs="+",
- default=[
- ("train", "loss", "min"),
- ("valid", "loss", "min"),
- ("train", "acc", "max"),
- ("valid", "acc", "max"),
- ],
- help="The criterion used for judging of the best model. "
- 'Give a pair referring the phase, "train" or "valid",'
- 'the criterion name, and the mode, "min" or "max", e.g. "acc,max".',
- )
- group.add_argument(
- "--keep_nbest_models",
- type=int,
- nargs="+",
- default=[10],
- help="Remove previous snapshots excluding the n-best scored epochs",
- )
- group.add_argument(
- "--nbest_averaging_interval",
- type=int,
- default=0,
- help="The epoch interval to apply model averaging and save nbest models",
- )
- group.add_argument(
- "--grad_clip",
- type=float,
- default=5.0,
- help="Gradient norm threshold to clip",
- )
- group.add_argument(
- "--grad_clip_type",
- type=float,
- default=2.0,
- help="The type of the used p-norm for gradient clip. Can be inf",
- )
- group.add_argument(
- "--grad_noise",
- type=str2bool,
- default=False,
- help="The flag to switch to use noise injection to "
- "gradients during training",
- )
- group.add_argument(
- "--accum_grad",
- type=int,
- default=1,
- help="The number of gradient accumulation",
- )
- group.add_argument(
- "--no_forward_run",
- type=str2bool,
- default=False,
- help="Just only iterating data loading without "
- "model forwarding and training",
- )
- group.add_argument(
- "--resume",
- type=str2bool,
- default=False,
- help="Enable resuming if checkpoint is existing",
- )
- group.add_argument(
- "--train_dtype",
- default="float32",
- choices=["float16", "float32", "float64"],
- help="Data type for training.",
- )
- group.add_argument(
- "--use_amp",
- type=str2bool,
- default=False,
- help="Enable Automatic Mixed Precision. This feature requires pytorch>=1.6",
- )
- group.add_argument(
- "--log_interval",
- type=int_or_none,
- default=None,
- help="Show the logs every the number iterations in each epochs at the "
- "training phase. If None is given, it is decided according the number "
- "of training samples automatically .",
- )
- group.add_argument(
- "--use_tensorboard",
- type=str2bool,
- default=True,
- help="Enable tensorboard logging",
- )
- group.add_argument(
- "--use_wandb",
- type=str2bool,
- default=False,
- help="Enable wandb logging",
- )
- group.add_argument(
- "--wandb_project",
- type=str,
- default=None,
- help="Specify wandb project",
- )
- group.add_argument(
- "--wandb_id",
- type=str,
- default=None,
- help="Specify wandb id",
- )
- group.add_argument(
- "--wandb_entity",
- type=str,
- default=None,
- help="Specify wandb entity",
- )
- group.add_argument(
- "--wandb_name",
- type=str,
- default=None,
- help="Specify wandb run name",
- )
- group.add_argument(
- "--wandb_model_log_interval",
- type=int,
- default=-1,
- help="Set the model log period",
- )
- group.add_argument(
- "--detect_anomaly",
- type=str2bool,
- default=False,
- help="Set torch.autograd.set_detect_anomaly",
- )
- group = parser.add_argument_group("Pretraining model related")
- group.add_argument("--pretrain_path", help="This option is obsoleted")
- group.add_argument(
- "--init_param",
- type=str,
- default=[],
- nargs="*",
- help="Specify the file path used for initialization of parameters. "
- "The format is '<file_path>:<src_key>:<dst_key>:<exclude_keys>', "
- "where file_path is the model file path, "
- "src_key specifies the key of model states to be used in the model file, "
- "dst_key specifies the attribute of the model to be initialized, "
- "and exclude_keys excludes keys of model states for the initialization."
- "e.g.\n"
- " # Load all parameters"
- " --init_param some/where/model.pth\n"
- " # Load only decoder parameters"
- " --init_param some/where/model.pth:decoder:decoder\n"
- " # Load only decoder parameters excluding decoder.embed"
- " --init_param some/where/model.pth:decoder:decoder:decoder.embed\n"
- " --init_param some/where/model.pth:decoder:decoder:decoder.embed\n",
- )
- group.add_argument(
- "--ignore_init_mismatch",
- type=str2bool,
- default=False,
- help="Ignore size mismatch when loading pre-trained model",
- )
- group.add_argument(
- "--freeze_param",
- type=str,
- default=[],
- nargs="*",
- help="Freeze parameters",
- )
- group = parser.add_argument_group("BatchSampler related")
- group.add_argument(
- "--num_iters_per_epoch",
- type=int_or_none,
- default=None,
- help="Restrict the number of iterations for training per epoch",
- )
- group.add_argument(
- "--batch_size",
- type=int,
- default=20,
- help="The mini-batch size used for training. Used if batch_type='unsorted',"
- " 'sorted', or 'folded'.",
- )
- group.add_argument(
- "--valid_batch_size",
- type=int_or_none,
- default=None,
- help="If not given, the value of --batch_size is used",
- )
- group.add_argument(
- "--batch_bins",
- type=int,
- default=1000000,
- help="The number of batch bins. Used if batch_type='length' or 'numel'",
- )
- group.add_argument(
- "--valid_batch_bins",
- type=int_or_none,
- default=None,
- help="If not given, the value of --batch_bins is used",
- )
- group.add_argument("--train_shape_file", type=str, action="append", default=[])
- group.add_argument("--valid_shape_file", type=str, action="append", default=[])
- group = parser.add_argument_group("Sequence iterator related")
- _batch_type_help = ""
- for key, value in BATCH_TYPES.items():
- _batch_type_help += f'"{key}":\n{value}\n'
- group.add_argument(
- "--batch_type",
- type=str,
- default="length",
- choices=list(BATCH_TYPES),
- help=_batch_type_help,
- )
- group.add_argument(
- "--valid_batch_type",
- type=str_or_none,
- default=None,
- choices=list(BATCH_TYPES) + [None],
- help="If not given, the value of --batch_type is used",
- )
- group.add_argument(
- "--speech_length_min",
- type=int,
- default=-1,
- help="speech length min",
- )
- group.add_argument(
- "--speech_length_max",
- type=int,
- default=-1,
- help="speech length max",
- )
- group.add_argument("--fold_length", type=int, action="append", default=[])
- group.add_argument(
- "--sort_in_batch",
- type=str,
- default="descending",
- choices=["descending", "ascending"],
- help="Sort the samples in each mini-batches by the sample "
- 'lengths. To enable this, "shape_file" must have the length information.',
- )
- group.add_argument(
- "--sort_batch",
- type=str,
- default="descending",
- choices=["descending", "ascending"],
- help="Sort mini-batches by the sample lengths",
- )
- group.add_argument(
- "--multiple_iterator",
- type=str2bool,
- default=False,
- help="Use multiple iterator mode",
- )
- group = parser.add_argument_group("Chunk iterator related")
- group.add_argument(
- "--chunk_length",
- type=str_or_int,
- default=500,
- help="Specify chunk length. e.g. '300', '300,400,500', or '300-400'."
- "If multiple numbers separated by command are given, "
- "one of them is selected randomly for each samples. "
- "If two numbers are given with '-', it indicates the range of the choices. "
- "Note that if the sequence length is shorter than the all chunk_lengths, "
- "the sample is discarded. ",
- )
- group.add_argument(
- "--chunk_shift_ratio",
- type=float,
- default=0.5,
- help="Specify the shift width of chunks. If it's less than 1, "
- "allows the overlapping and if bigger than 1, there are some gaps "
- "between each chunk.",
- )
- group.add_argument(
- "--num_cache_chunks",
- type=int,
- default=1024,
- help="Shuffle in the specified number of chunks and generate mini-batches "
- "More larger this value, more randomness can be obtained.",
- )
- group = parser.add_argument_group("Dataset related")
- _data_path_and_name_and_type_help = (
- "Give three words splitted by comma. It's used for the training data. "
- "e.g. '--train_data_path_and_name_and_type some/path/a.scp,foo,sound'. "
- "The first value, some/path/a.scp, indicates the file path, "
- "and the second, foo, is the key name used for the mini-batch data, "
- "and the last, sound, decides the file type. "
- "This option is repeatable, so you can input any number of features "
- "for your task. Supported file types are as follows:\n\n"
- )
- for key, dic in DATA_TYPES.items():
- _data_path_and_name_and_type_help += f'"{key}":\n{dic["help"]}\n\n'
- # for large dataset
- group.add_argument(
- "--dataset_type",
- type=str,
- default="small",
- help="whether to use dataloader for large dataset",
- )
- parser.add_argument(
- "--dataset_conf",
- action=NestedDictAction,
- default=dict(),
- help=f"The keyword arguments for dataset",
- )
- group.add_argument(
- "--train_data_file",
- type=str,
- default=None,
- help="train_list for large dataset",
- )
- group.add_argument(
- "--valid_data_file",
- type=str,
- default=None,
- help="valid_list for large dataset",
- )
- group.add_argument(
- "--train_data_path_and_name_and_type",
- type=str2triple_str,
- action="append",
- default=[],
- help=_data_path_and_name_and_type_help,
- )
- group.add_argument(
- "--valid_data_path_and_name_and_type",
- type=str2triple_str,
- action="append",
- default=[],
- )
- group.add_argument(
- "--allow_variable_data_keys",
- type=str2bool,
- default=False,
- help="Allow the arbitrary keys for mini-batch with ignoring "
- "the task requirements",
- )
- group.add_argument(
- "--max_cache_size",
- type=humanfriendly.parse_size,
- default=0.0,
- help="The maximum cache size for data loader. e.g. 10MB, 20GB.",
- )
- group.add_argument(
- "--max_cache_fd",
- type=int,
- default=32,
- help="The maximum number of file descriptors to be kept "
- "as opened for ark files. "
- "This feature is only valid when data type is 'kaldi_ark'.",
- )
- group.add_argument(
- "--valid_max_cache_size",
- type=humanfriendly_parse_size_or_none,
- default=None,
- help="The maximum cache size for validation data loader. e.g. 10MB, 20GB. "
- "If None, the 5 percent size of --max_cache_size",
- )
- group = parser.add_argument_group("Optimizer related")
- for i in range(1, cls.num_optimizers + 1):
- suf = "" if i == 1 else str(i)
- group.add_argument(
- f"--optim{suf}",
- type=lambda x: x.lower(),
- default="adadelta",
- choices=list(optim_classes),
- help="The optimizer type",
- )
- group.add_argument(
- f"--optim{suf}_conf",
- action=NestedDictAction,
- default=dict(),
- help="The keyword arguments for optimizer",
- )
- group.add_argument(
- f"--scheduler{suf}",
- type=lambda x: str_or_none(x.lower()),
- default=None,
- choices=list(scheduler_classes) + [None],
- help="The lr scheduler type",
- )
- group.add_argument(
- f"--scheduler{suf}_conf",
- action=NestedDictAction,
- default=dict(),
- help="The keyword arguments for lr scheduler",
- )
- # for training on PAI
- group = parser.add_argument_group("PAI training related")
- group.add_argument(
- "--use_pai",
- type=str2bool,
- default=False,
- help="flag to indicate whether training on PAI",
- )
- group.add_argument(
- "--simple_ddp",
- type=str2bool,
- default=False,
- )
- group.add_argument(
- "--num_worker_count",
- type=int,
- default=1,
- help="The number of machines on PAI.",
- )
- group.add_argument(
- "--access_key_id",
- type=str,
- default=None,
- help="The username for oss.",
- )
- group.add_argument(
- "--access_key_secret",
- type=str,
- default=None,
- help="The password for oss.",
- )
- group.add_argument(
- "--endpoint",
- type=str,
- default=None,
- help="The endpoint for oss.",
- )
- group.add_argument(
- "--bucket_name",
- type=str,
- default=None,
- help="The bucket name for oss.",
- )
- group.add_argument(
- "--oss_bucket",
- default=None,
- help="oss bucket.",
- )
- cls.trainer.add_arguments(parser)
- cls.add_task_arguments(parser)
- assert check_return_type(parser)
- return parser
- @classmethod
- def build_optimizers(
- cls,
- args: argparse.Namespace,
- model: torch.nn.Module,
- ) -> List[torch.optim.Optimizer]:
- if cls.num_optimizers != 1:
- raise RuntimeError(
- "build_optimizers() must be overridden if num_optimizers != 1"
- )
- optim_class = optim_classes.get(args.optim)
- if optim_class is None:
- raise ValueError(f"must be one of {list(optim_classes)}: {args.optim}")
- if args.sharded_ddp:
- if fairscale is None:
- raise RuntimeError("Requiring fairscale. Do 'pip install fairscale'")
- optim = fairscale.optim.oss.OSS(
- params=model.parameters(), optim=optim_class, **args.optim_conf
- )
- else:
- optim = optim_class(model.parameters(), **args.optim_conf)
- optimizers = [optim]
- return optimizers
- @classmethod
- def exclude_opts(cls) -> Tuple[str, ...]:
- """The options not to be shown by --print_config"""
- return "required", "print_config", "config", "ngpu"
- @classmethod
- def get_default_config(cls) -> Dict[str, Any]:
- """Return the configuration as dict.
- This method is used by print_config()
- """
- def get_class_type(name: str, classes: dict):
- _cls = classes.get(name)
- if _cls is None:
- raise ValueError(f"must be one of {list(classes)}: {name}")
- return _cls
- # This method is used only for --print_config
- assert check_argument_types()
- parser = cls.get_parser()
- args, _ = parser.parse_known_args()
- config = vars(args)
- # Excludes the options not to be shown
- for k in AbsTask.exclude_opts():
- config.pop(k)
- for i in range(1, cls.num_optimizers + 1):
- suf = "" if i == 1 else str(i)
- name = config[f"optim{suf}"]
- optim_class = get_class_type(name, optim_classes)
- conf = get_default_kwargs(optim_class)
- # Overwrite the default by the arguments,
- conf.update(config[f"optim{suf}_conf"])
- # and set it again
- config[f"optim{suf}_conf"] = conf
- name = config[f"scheduler{suf}"]
- if name is not None:
- scheduler_class = get_class_type(name, scheduler_classes)
- conf = get_default_kwargs(scheduler_class)
- # Overwrite the default by the arguments,
- conf.update(config[f"scheduler{suf}_conf"])
- # and set it again
- config[f"scheduler{suf}_conf"] = conf
- for class_choices in cls.class_choices_list:
- if getattr(args, class_choices.name) is not None:
- class_obj = class_choices.get_class(getattr(args, class_choices.name))
- conf = get_default_kwargs(class_obj)
- name = class_choices.name
- # Overwrite the default by the arguments,
- conf.update(config[f"{name}_conf"])
- # and set it again
- config[f"{name}_conf"] = conf
- return config
- @classmethod
- def check_required_command_args(cls, args: argparse.Namespace):
- assert check_argument_types()
- if hasattr(args, "required"):
- for k in vars(args):
- if "-" in k:
- raise RuntimeError(f'Use "_" instead of "-": parser.get_parser("{k}")')
- required = ", ".join(
- f"--{a}" for a in args.required if getattr(args, a) is None
- )
- if len(required) != 0:
- parser = cls.get_parser()
- parser.print_help(file=sys.stderr)
- p = Path(sys.argv[0]).name
- print(file=sys.stderr)
- print(
- f"{p}: error: the following arguments are required: " f"{required}",
- file=sys.stderr,
- )
- sys.exit(2)
- @classmethod
- def check_task_requirements(
- cls,
- dataset: Union[AbsDataset, IterableESPnetDataset],
- allow_variable_data_keys: bool,
- train: bool,
- inference: bool = False,
- ) -> None:
- """Check if the dataset satisfy the requirement of current Task"""
- assert check_argument_types()
- mes = (
- f"If you intend to use an additional input, modify "
- f'"{cls.__name__}.required_data_names()" or '
- f'"{cls.__name__}.optional_data_names()". '
- f"Otherwise you need to set --allow_variable_data_keys true "
- )
- for k in cls.required_data_names(train, inference):
- if not dataset.has_name(k):
- raise RuntimeError(
- f'"{cls.required_data_names(train, inference)}" are required for'
- f' {cls.__name__}. but "{dataset.names()}" are input.\n{mes}'
- )
- if not allow_variable_data_keys:
- task_keys = cls.required_data_names(
- train, inference
- ) + cls.optional_data_names(train, inference)
- for k in dataset.names():
- if k not in task_keys:
- raise RuntimeError(
- f"The data-name must be one of {task_keys} "
- f'for {cls.__name__}: "{k}" is not allowed.\n{mes}'
- )
- @classmethod
- def print_config(cls, file=sys.stdout) -> None:
- assert check_argument_types()
- # Shows the config: e.g. python train.py asr --print_config
- config = cls.get_default_config()
- file.write(yaml_no_alias_safe_dump(config, indent=4, sort_keys=False))
- @classmethod
- def main(cls, args: argparse.Namespace = None, cmd: Sequence[str] = None):
- assert check_argument_types()
- print(get_commandline_args(), file=sys.stderr)
- if args is None:
- parser = cls.get_parser()
- args = parser.parse_args(cmd)
- args.version = __version__
- if args.pretrain_path is not None:
- raise RuntimeError("--pretrain_path is deprecated. Use --init_param")
- if args.print_config:
- cls.print_config()
- sys.exit(0)
- cls.check_required_command_args(args)
- if not args.distributed or not args.multiprocessing_distributed:
- cls.main_worker(args)
- else:
- assert args.ngpu > 1
- cls.main_worker(args)
- @classmethod
- def run(cls):
- assert hasattr(cls, "finetune_args")
- args = cls.finetune_args
- args.train_shape_file = None
- if args.distributed:
- args.simple_ddp = True
- else:
- args.simple_ddp = False
- args.ngpu = 1
- args.use_pai = False
- args.batch_type = "length"
- args.oss_bucket = None
- args.input_size = None
- cls.main_worker(args)
    @classmethod
    def main_worker(cls, args: argparse.Namespace):
        """Run the full training pipeline in one worker process.

        Steps: distributed setup -> shape/data-list preparation -> logging
        setup -> seed/cudnn config -> model, optimizer and scheduler
        construction -> config dump -> pretrained-parameter loading ->
        iterator construction -> cls.trainer.run().
        """
        assert check_argument_types()

        # 0. Init distributed process
        distributed_option = build_dataclass(DistributedOption, args)
        # Setting distributed_option.dist_rank, etc.
        if args.use_pai:
            distributed_option.init_options_pai()
        elif not args.simple_ddp:
            distributed_option.init_options()

        # Invoking torch.distributed.init_process_group
        if args.use_pai:
            distributed_option.init_torch_distributed_pai(args)
        elif not args.simple_ddp:
            distributed_option.init_torch_distributed(args)
        elif args.distributed and args.simple_ddp:
            distributed_option.init_torch_distributed_pai(args)
            args.ngpu = dist.get_world_size()
            # NOTE(review): in simple-DDP mode the global batch settings are
            # scaled by the world size — presumably each rank later takes its
            # share when batches are sharded; confirm against the iterators.
            if args.dataset_type == "small":
                if args.batch_size is not None:
                    args.batch_size = args.batch_size * args.ngpu
                if args.batch_bins is not None:
                    args.batch_bins = args.batch_bins * args.ngpu

        # Generate missing speech_shape files (small dataset type) on rank 0
        # only; other ranks wait at the barrier, then all pick up the paths.
        if args.train_shape_file is None and args.dataset_type == "small":
            if not args.simple_ddp or distributed_option.dist_rank == 0:
                calc_shape(args.data_dir, args.train_set, args.frontend_conf, args.speech_length_min, args.speech_length_max)
                calc_shape(args.data_dir, args.dev_set, args.frontend_conf, args.speech_length_min, args.speech_length_max)
            if args.simple_ddp:
                dist.barrier()
            args.train_shape_file = [os.path.join(args.data_dir, args.train_set, "speech_shape")]
            args.valid_shape_file = [os.path.join(args.data_dir, args.dev_set, "speech_shape")]

        # Same rank-0 + barrier pattern for data.list (large dataset type).
        if args.train_data_file is None and args.dataset_type == "large":
            if not args.simple_ddp or distributed_option.dist_rank == 0:
                generate_data_list(args.data_dir, args.train_set)
                generate_data_list(args.data_dir, args.dev_set)
            if args.simple_ddp:
                dist.barrier()
            args.train_data_file = os.path.join(args.data_dir, args.train_set, "data.list")
            args.valid_data_file = os.path.join(args.data_dir, args.dev_set, "data.list")

        # NOTE(kamo): Don't use logging before invoking logging.basicConfig()
        if not distributed_option.distributed or distributed_option.dist_rank == 0:
            if not distributed_option.distributed:
                _rank = ""
            else:
                # _rank is currently unused below; kept as in the original.
                _rank = (
                    f":{distributed_option.dist_rank}/"
                    f"{distributed_option.dist_world_size}"
                )

            # NOTE(kamo):
            # logging.basicConfig() is invoked in main_worker() instead of main()
            # because it can be invoked only once in a process.
            # FIXME(kamo): Should we use logging.getLogger()?
            logging.basicConfig(
                level=args.log_level,
                format=f"[{os.uname()[1].split('.')[0]}]"
                f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
            )
        else:
            # Suppress logging if RANK != 0
            logging.basicConfig(
                level="ERROR",
                format=f"[{os.uname()[1].split('.')[0]}]"
                f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
            )
        logging.info("world size: {}, rank: {}, local_rank: {}".format(distributed_option.dist_world_size,
                                                                       distributed_option.dist_rank,
                                                                       distributed_option.local_rank))

        # 1. Set random-seed
        set_all_random_seed(args.seed)
        torch.backends.cudnn.enabled = args.cudnn_enabled
        torch.backends.cudnn.benchmark = args.cudnn_benchmark
        torch.backends.cudnn.deterministic = args.cudnn_deterministic
        if args.detect_anomaly:
            logging.info("Invoking torch.autograd.set_detect_anomaly(True)")
            torch.autograd.set_detect_anomaly(args.detect_anomaly)

        # 2. Build model
        model = cls.build_model(args=args)
        if not isinstance(model, AbsESPnetModel):
            raise RuntimeError(
                f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
            )
        model = model.to(
            dtype=getattr(torch, args.train_dtype),
            device="cuda" if args.ngpu > 0 else "cpu",
        )
        # Freeze every parameter whose name equals, or is prefixed by,
        # an entry of --freeze_param.
        for t in args.freeze_param:
            for k, p in model.named_parameters():
                if k.startswith(t + ".") or k == t:
                    logging.info(f"Setting {k}.requires_grad = False")
                    p.requires_grad = False

        # 3. Build optimizer
        optimizers = cls.build_optimizers(args, model=model)

        # 4. Build schedulers (one per optimizer; None when unset)
        schedulers = []
        for i, optim in enumerate(optimizers, 1):
            suf = "" if i == 1 else str(i)
            name = getattr(args, f"scheduler{suf}")
            conf = getattr(args, f"scheduler{suf}_conf")
            if name is not None:
                cls_ = scheduler_classes.get(name)
                if cls_ is None:
                    raise ValueError(
                        f"must be one of {list(scheduler_classes)}: {name}"
                    )
                scheduler = cls_(optim, **conf)
            else:
                scheduler = None
            schedulers.append(scheduler)

        logging.info(pytorch_cudnn_version())
        logging.info(model_summary(model))
        for i, (o, s) in enumerate(zip(optimizers, schedulers), 1):
            suf = "" if i == 1 else str(i)
            logging.info(f"Optimizer{suf}:\n{o}")
            logging.info(f"Scheduler{suf}: {s}")

        # 5. Dump "args" to config.yaml
        # NOTE(kamo): "args" should be saved after object-buildings are done
        # because they are allowed to modify "args".
        output_dir = Path(args.output_dir)
        if not distributed_option.distributed or distributed_option.dist_rank == 0:
            output_dir.mkdir(parents=True, exist_ok=True)
            with (output_dir / "config.yaml").open("w", encoding="utf-8") as f:
                logging.info(
                    f'Saving the configuration in {output_dir / "config.yaml"}'
                )
                if args.use_pai:
                    # On PAI the config is serialized to OSS instead of disk.
                    buffer = BytesIO()
                    torch.save({"config": vars(args)}, buffer)
                    args.oss_bucket.put_object(os.path.join(args.output_dir, "config.dict"), buffer.getvalue())
                else:
                    yaml_no_alias_safe_dump(vars(args), f, indent=4, sort_keys=False)
        if args.dry_run:
            pass
        else:
            logging.info("Training args: {}".format(args))

        # 6. Loads pre-trained model
        for p in args.init_param:
            logging.info(f"Loading pretrained params from {p}")
            load_pretrained_model(
                model=model,
                init_param=p,
                ignore_init_mismatch=args.ignore_init_mismatch,
                # NOTE(kamo): "cuda" for torch.load always indicates cuda:0
                # in PyTorch<=1.4
                map_location=f"cuda:{torch.cuda.current_device()}"
                if args.ngpu > 0
                else "cpu",
                oss_bucket=args.oss_bucket,
            )

        # 7. Build iterator factories
        if args.dataset_type == "large":
            from funasr.datasets.large_datasets.build_dataloader import ArkDataLoader
            train_iter_factory = ArkDataLoader(args.train_data_file, args.token_list, args.dataset_conf,
                                               seg_dict_file=args.seg_dict_file if hasattr(args,
                                                                                           "seg_dict_file") else None,
                                               mode="train")
            valid_iter_factory = ArkDataLoader(args.valid_data_file, args.token_list, args.dataset_conf,
                                               seg_dict_file=args.seg_dict_file if hasattr(args,
                                                                                           "seg_dict_file") else None,
                                               mode="eval")
        elif args.dataset_type == "small":
            train_iter_factory = cls.build_iter_factory(
                args=args,
                distributed_option=distributed_option,
                mode="train",
            )
            valid_iter_factory = cls.build_iter_factory(
                args=args,
                distributed_option=distributed_option,
                mode="valid",
            )
        else:
            raise ValueError(f"Not supported dataset_type={args.dataset_type}")

        if args.scheduler == "tri_stage":
            # NOTE(review): "scheudler" typo must match the scheduler class's
            # actual method name — confirm before renaming.
            for scheduler in schedulers:
                scheduler.init_tri_stage_scheudler(max_update=args.max_update)

        # 8. Start training
        if args.use_wandb:
            if wandb is None:
                raise RuntimeError("Please install wandb")
            try:
                wandb.login()
            except wandb.errors.UsageError:
                logging.info("wandb not configured! run `wandb login` to enable")
                args.use_wandb = False
        if args.use_wandb:
            if (
                not distributed_option.distributed
                or distributed_option.dist_rank == 0
            ):
                if args.wandb_project is None:
                    project = "FunASR_" + cls.__name__
                else:
                    project = args.wandb_project
                if args.wandb_name is None:
                    name = str(Path(".").resolve()).replace("/", "_")
                else:
                    name = args.wandb_name
                wandb.init(
                    entity=args.wandb_entity,
                    project=project,
                    name=name,
                    dir=output_dir,
                    id=args.wandb_id,
                    resume="allow",
                )
                wandb.config.update(args)
            else:
                # wandb also supports grouping for distributed training,
                # but we only logs aggregated data,
                # so it's enough to perform on rank0 node.
                args.use_wandb = False

        # Don't give args to trainer.run() directly!!!
        # Instead of it, define "Options" object and build here.
        trainer_options = cls.trainer.build_options(args)
        cls.trainer.run(
            model=model,
            optimizers=optimizers,
            schedulers=schedulers,
            train_iter_factory=train_iter_factory,
            valid_iter_factory=valid_iter_factory,
            trainer_options=trainer_options,
            distributed_option=distributed_option,
        )

        if args.use_wandb and wandb.run:
            wandb.finish()
- @classmethod
- def build_iter_options(
- cls,
- args: argparse.Namespace,
- distributed_option: DistributedOption,
- mode: str,
- ):
- if mode == "train":
- preprocess_fn = cls.build_preprocess_fn(args, train=True)
- collate_fn = cls.build_collate_fn(args, train=True)
- data_path_and_name_and_type = args.train_data_path_and_name_and_type
- shape_files = args.train_shape_file
- batch_size = args.batch_size
- batch_bins = args.batch_bins
- batch_type = args.batch_type
- max_cache_size = args.max_cache_size
- max_cache_fd = args.max_cache_fd
- distributed = distributed_option.distributed
- num_batches = None
- num_iters_per_epoch = args.num_iters_per_epoch
- train = True
- elif mode == "valid":
- preprocess_fn = cls.build_preprocess_fn(args, train=False)
- collate_fn = cls.build_collate_fn(args, train=False)
- data_path_and_name_and_type = args.valid_data_path_and_name_and_type
- shape_files = args.valid_shape_file
- if args.valid_batch_type is None:
- batch_type = args.batch_type
- else:
- batch_type = args.valid_batch_type
- if args.valid_batch_size is None:
- batch_size = args.batch_size
- else:
- batch_size = args.valid_batch_size
- if args.valid_batch_bins is None:
- batch_bins = args.batch_bins
- else:
- batch_bins = args.valid_batch_bins
- if args.valid_max_cache_size is None:
- # Cache 5% of maximum size for validation loader
- max_cache_size = 0.05 * args.max_cache_size
- else:
- max_cache_size = args.valid_max_cache_size
- max_cache_fd = args.max_cache_fd
- distributed = distributed_option.distributed
- num_batches = None
- num_iters_per_epoch = None
- train = False
- else:
- raise NotImplementedError(f"mode={mode}")
- return IteratorOptions(
- preprocess_fn=preprocess_fn,
- collate_fn=collate_fn,
- data_path_and_name_and_type=data_path_and_name_and_type,
- shape_files=shape_files,
- batch_type=batch_type,
- batch_size=batch_size,
- batch_bins=batch_bins,
- num_batches=num_batches,
- max_cache_size=max_cache_size,
- max_cache_fd=max_cache_fd,
- distributed=distributed,
- num_iters_per_epoch=num_iters_per_epoch,
- train=train,
- )
- @classmethod
- def build_iter_factory(
- cls,
- args: argparse.Namespace,
- distributed_option: DistributedOption,
- mode: str,
- kwargs: dict = None,
- ) -> AbsIterFactory:
- """Build a factory object of mini-batch iterator.
- This object is invoked at every epochs to build the iterator for each epoch
- as following:
- >>> iter_factory = cls.build_iter_factory(...)
- >>> for epoch in range(1, max_epoch):
- ... for keys, batch in iter_fatory.build_iter(epoch):
- ... model(**batch)
- The mini-batches for each epochs are fully controlled by this class.
- Note that the random seed used for shuffling is decided as "seed + epoch" and
- the generated mini-batches can be reproduces when resuming.
- Note that the definition of "epoch" doesn't always indicate
- to run out of the whole training corpus.
- "--num_iters_per_epoch" option restricts the number of iterations for each epoch
- and the rest of samples for the originally epoch are left for the next epoch.
- e.g. If The number of mini-batches equals to 4, the following two are same:
- - 1 epoch without "--num_iters_per_epoch"
- - 4 epoch with "--num_iters_per_epoch" == 4
- """
- assert check_argument_types()
- iter_options = cls.build_iter_options(args, distributed_option, mode)
- # Overwrite iter_options if any kwargs is given
- if kwargs is not None:
- for k, v in kwargs.items():
- setattr(iter_options, k, v)
- if args.iterator_type == "sequence":
- return cls.build_sequence_iter_factory(
- args=args,
- iter_options=iter_options,
- mode=mode,
- )
- elif args.iterator_type == "chunk":
- return cls.build_chunk_iter_factory(
- args=args,
- iter_options=iter_options,
- mode=mode,
- )
- elif args.iterator_type == "task":
- return cls.build_task_iter_factory(
- args=args,
- iter_options=iter_options,
- mode=mode,
- )
- else:
- raise RuntimeError(f"Not supported: iterator_type={args.iterator_type}")
    @classmethod
    def build_sequence_iter_factory(
        cls, args: argparse.Namespace, iter_options: IteratorOptions, mode: str
    ) -> AbsIterFactory:
        """Build a SequenceIterFactory for whole-utterance iteration.

        Materializes the batch list up front so it can be truncated, logged,
        and (in distributed mode) sharded across ranks.
        """
        assert check_argument_types()

        dataset = ESPnetDataset(
            iter_options.data_path_and_name_and_type,
            float_dtype=args.train_dtype,
            preprocess=iter_options.preprocess_fn,
            max_cache_size=iter_options.max_cache_size,
            max_cache_fd=iter_options.max_cache_fd,
        )
        cls.check_task_requirements(
            dataset, args.allow_variable_data_keys, train=iter_options.train
        )

        # An optional "utt2category" file next to the first data file makes
        # the sampler group utterances by category.
        if Path(
            Path(iter_options.data_path_and_name_and_type[0][0]).parent, "utt2category"
        ).exists():
            utt2category_file = str(
                Path(
                    Path(iter_options.data_path_and_name_and_type[0][0]).parent,
                    "utt2category",
                )
            )
        else:
            utt2category_file = None
        batch_sampler = build_batch_sampler(
            type=iter_options.batch_type,
            shape_files=iter_options.shape_files,
            fold_lengths=args.fold_length,
            batch_size=iter_options.batch_size,
            batch_bins=iter_options.batch_bins,
            sort_in_batch=args.sort_in_batch,
            sort_batch=args.sort_batch,
            drop_last=False,
            # Each rank must receive at least one sample per batch.
            min_batch_size=torch.distributed.get_world_size()
            if iter_options.distributed
            else 1,
            utt2category_file=utt2category_file,
        )

        batches = list(batch_sampler)
        if iter_options.num_batches is not None:
            batches = batches[: iter_options.num_batches]

        bs_list = [len(batch) for batch in batches]

        logging.info(f"[{mode}] dataset:\n{dataset}")
        logging.info(f"[{mode}] Batch sampler: {batch_sampler}")
        logging.info(
            f"[{mode}] mini-batch sizes summary: N-batch={len(bs_list)}, "
            f"mean={np.mean(bs_list):.1f}, min={np.min(bs_list)}, max={np.max(bs_list)}"
        )

        # Side effect: the tri-stage scheduler's total update count is derived
        # from the training batch count and stored back onto args.
        if args.scheduler == "tri_stage" and mode == "train":
            args.max_update = len(bs_list) * args.max_epoch
            logging.info("Max update: {}".format(args.max_update))

        if iter_options.distributed:
            # Shard every batch across ranks with a stride of world_size, so
            # all ranks see the same number of (smaller) batches.
            world_size = torch.distributed.get_world_size()
            rank = torch.distributed.get_rank()
            for batch in batches:
                if len(batch) < world_size:
                    raise RuntimeError(
                        f"The batch-size must be equal or more than world_size: "
                        f"{len(batch)} < {world_size}"
                    )
            batches = [batch[rank::world_size] for batch in batches]

        return SequenceIterFactory(
            dataset=dataset,
            batches=batches,
            seed=args.seed,
            num_iters_per_epoch=iter_options.num_iters_per_epoch,
            shuffle=iter_options.train,
            num_workers=args.num_workers,
            collate_fn=iter_options.collate_fn,
            pin_memory=args.ngpu > 0,
        )
    @classmethod
    def build_chunk_iter_factory(
        cls,
        args: argparse.Namespace,
        iter_options: IteratorOptions,
        mode: str,
    ) -> AbsIterFactory:
        """Build a ChunkIterFactory that serves fixed-length chunks.

        Samples are drawn one utterance at a time (UnsortedBatchSampler with
        batch_size=1); the real mini-batches are assembled from chunks inside
        ChunkIterFactory.
        """
        assert check_argument_types()

        dataset = ESPnetDataset(
            iter_options.data_path_and_name_and_type,
            float_dtype=args.train_dtype,
            preprocess=iter_options.preprocess_fn,
            max_cache_size=iter_options.max_cache_size,
            max_cache_fd=iter_options.max_cache_fd,
        )
        cls.check_task_requirements(
            dataset, args.allow_variable_data_keys, train=iter_options.train
        )

        # Use the first shape file as the key list, falling back to the first
        # data file when no shape files exist.
        if len(iter_options.shape_files) == 0:
            key_file = iter_options.data_path_and_name_and_type[0][0]
        else:
            key_file = iter_options.shape_files[0]

        batch_sampler = UnsortedBatchSampler(batch_size=1, key_file=key_file)
        batches = list(batch_sampler)
        if iter_options.num_batches is not None:
            batches = batches[: iter_options.num_batches]
        logging.info(f"[{mode}] dataset:\n{dataset}")

        if iter_options.distributed:
            world_size = torch.distributed.get_world_size()
            rank = torch.distributed.get_rank()
            if len(batches) < world_size:
                raise RuntimeError("Number of samples is smaller than world_size")
            if iter_options.batch_size < world_size:
                raise RuntimeError("batch_size must be equal or more than world_size")

            # Spread the remainder of batch_size / world_size over low ranks.
            if rank < iter_options.batch_size % world_size:
                batch_size = iter_options.batch_size // world_size + 1
            else:
                batch_size = iter_options.batch_size // world_size
            num_cache_chunks = args.num_cache_chunks // world_size
            # NOTE(kamo): Split whole corpus by sample numbers without considering
            # each of the lengths, therefore the number of iteration counts are not
            # always equal to each other and the iterations are limitted
            # by the fewest iterations.
            # i.e. the samples over the counts are discarded.
            batches = batches[rank::world_size]
        else:
            batch_size = iter_options.batch_size
            num_cache_chunks = args.num_cache_chunks

        return ChunkIterFactory(
            dataset=dataset,
            batches=batches,
            seed=args.seed,
            batch_size=batch_size,
            # For chunk iterator,
            # --num_iters_per_epoch doesn't indicate the number of iterations,
            # but indicates the number of samples.
            num_samples_per_epoch=iter_options.num_iters_per_epoch,
            shuffle=iter_options.train,
            num_workers=args.num_workers,
            collate_fn=iter_options.collate_fn,
            pin_memory=args.ngpu > 0,
            chunk_length=args.chunk_length,
            chunk_shift_ratio=args.chunk_shift_ratio,
            num_cache_chunks=num_cache_chunks,
        )
- # NOTE(kamo): Not abstract class
- @classmethod
- def build_task_iter_factory(
- cls,
- args: argparse.Namespace,
- iter_options: IteratorOptions,
- mode: str,
- ) -> AbsIterFactory:
- """Build task specific iterator factory
- Example:
- >>> class YourTask(AbsTask):
- ... @classmethod
- ... def add_task_arguments(cls, parser: argparse.ArgumentParser):
- ... parser.set_defaults(iterator_type="task")
- ...
- ... @classmethod
- ... def build_task_iter_factory(
- ... cls,
- ... args: argparse.Namespace,
- ... iter_options: IteratorOptions,
- ... mode: str,
- ... ):
- ... return FooIterFactory(...)
- ...
- ... @classmethod
- ... def build_iter_options(
- .... args: argparse.Namespace,
- ... distributed_option: DistributedOption,
- ... mode: str
- ... ):
- ... # if you need to customize options object
- """
- raise NotImplementedError
    @classmethod
    def build_multiple_iter_factory(
        cls, args: argparse.Namespace, distributed_option: DistributedOption, mode: str
    ):
        """Build a MultipleIterFactory over a corpus pre-split into N parts.

        Every data/shape path must be a directory containing a "num_splits"
        file and "split.{i}" entries; one sub-factory per split is then built
        lazily at epoch time.
        """
        assert check_argument_types()
        iter_options = cls.build_iter_options(args, distributed_option, mode)
        assert len(iter_options.data_path_and_name_and_type) > 0, len(
            iter_options.data_path_and_name_and_type
        )

        # 1. Sanity check
        num_splits = None
        for path in [
            path for path, _, _ in iter_options.data_path_and_name_and_type
        ] + list(iter_options.shape_files):
            if not Path(path).is_dir():
                raise RuntimeError(f"{path} is not a directory")
            p = Path(path) / "num_splits"
            if not p.exists():
                raise FileNotFoundError(f"{p} is not found")
            with p.open() as f:
                _num_splits = int(f.read())
                # All directories must agree on the same split count.
                if num_splits is not None and num_splits != _num_splits:
                    raise RuntimeError(
                        f"Number of splits are mismathed: "
                        f"{iter_options.data_path_and_name_and_type[0][0]} and {path}"
                    )
                num_splits = _num_splits

            # Every split directory entry must actually exist.
            for i in range(num_splits):
                p = Path(path) / f"split.{i}"
                if not p.exists():
                    raise FileNotFoundError(f"{p} is not found")

        # 2. Create functions to build an iter factory for each splits
        data_path_and_name_and_type_list = [
            [
                (str(Path(p) / f"split.{i}"), n, t)
                for p, n, t in iter_options.data_path_and_name_and_type
            ]
            for i in range(num_splits)
        ]
        shape_files_list = [
            [str(Path(s) / f"split.{i}") for s in iter_options.shape_files]
            for i in range(num_splits)
        ]
        # Distribute num_iters_per_epoch across splits (rounding spread so the
        # totals add up); None propagates unchanged.
        num_iters_per_epoch_list = [
            (iter_options.num_iters_per_epoch + i) // num_splits
            if iter_options.num_iters_per_epoch is not None
            else None
            for i in range(num_splits)
        ]
        # Each split receives a proportional share of the cache budget.
        max_cache_size = iter_options.max_cache_size / num_splits

        # Note that iter-factories are built for each epoch at runtime lazily.
        build_funcs = [
            functools.partial(
                cls.build_iter_factory,
                args,
                distributed_option,
                mode,
                kwargs=dict(
                    data_path_and_name_and_type=_data_path_and_name_and_type,
                    shape_files=_shape_files,
                    num_iters_per_epoch=_num_iters_per_epoch,
                    max_cache_size=max_cache_size,
                ),
            )
            for (
                _data_path_and_name_and_type,
                _shape_files,
                _num_iters_per_epoch,
            ) in zip(
                data_path_and_name_and_type_list,
                shape_files_list,
                num_iters_per_epoch_list,
            )
        ]

        # 3. Build MultipleIterFactory
        return MultipleIterFactory(
            build_funcs=build_funcs, shuffle=iter_options.train, seed=args.seed
        )
- @classmethod
- def build_streaming_iterator(
- cls,
- data_path_and_name_and_type,
- preprocess_fn,
- collate_fn,
- key_file: str = None,
- batch_size: int = 1,
- dtype: str = np.float32,
- num_workers: int = 1,
- allow_variable_data_keys: bool = False,
- ngpu: int = 0,
- inference: bool = False,
- ) -> DataLoader:
- """Build DataLoader using iterable dataset"""
- assert check_argument_types()
- # For backward compatibility for pytorch DataLoader
- if collate_fn is not None:
- kwargs = dict(collate_fn=collate_fn)
- else:
- kwargs = {}
- dataset = IterableESPnetDataset(
- data_path_and_name_and_type,
- float_dtype=dtype,
- preprocess=preprocess_fn,
- key_file=key_file,
- )
- if dataset.apply_utt2category:
- kwargs.update(batch_size=1)
- else:
- kwargs.update(batch_size=batch_size)
- cls.check_task_requirements(
- dataset, allow_variable_data_keys, train=False, inference=inference
- )
- return DataLoader(
- dataset=dataset,
- pin_memory=ngpu > 0,
- num_workers=num_workers,
- **kwargs,
- )
- # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
- @classmethod
- def build_model_from_file(
- cls,
- config_file: Union[Path, str] = None,
- model_file: Union[Path, str] = None,
- cmvn_file: Union[Path, str] = None,
- device: str = "cpu",
- ) -> Tuple[AbsESPnetModel, argparse.Namespace]:
- """Build model from the files.
- This method is used for inference or fine-tuning.
- Args:
- config_file: The yaml file saved when training.
- model_file: The model file saved when training.
- device: Device type, "cpu", "cuda", or "cuda:N".
- """
- assert check_argument_types()
- if config_file is None:
- assert model_file is not None, (
- "The argument 'model_file' must be provided "
- "if the argument 'config_file' is not specified."
- )
- config_file = Path(model_file).parent / "config.yaml"
- else:
- config_file = Path(config_file)
- with config_file.open("r", encoding="utf-8") as f:
- args = yaml.safe_load(f)
- if cmvn_file is not None:
- args["cmvn_file"] = cmvn_file
- args = argparse.Namespace(**args)
- model = cls.build_model(args)
- if not isinstance(model, AbsESPnetModel):
- raise RuntimeError(
- f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
- )
- model.to(device)
- if model_file is not None:
- if device == "cuda":
- # NOTE(kamo): "cuda" for torch.load always indicates cuda:0
- # in PyTorch<=1.4
- device = f"cuda:{torch.cuda.current_device()}"
- model.load_state_dict(torch.load(model_file, map_location=device))
- model.to(device)
- return model, args
|