# (removed: code-viewer residue — file-size header and concatenated line-number gutter)
  1. # Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
  2. # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
  3. """Abstract task module."""
  4. import argparse
  5. import functools
  6. import logging
  7. import os
  8. import sys
  9. from abc import ABC
  10. from abc import abstractmethod
  11. from dataclasses import dataclass
  12. from distutils.version import LooseVersion
  13. from io import BytesIO
  14. from pathlib import Path
  15. from typing import Any
  16. from typing import Callable
  17. from typing import Dict
  18. from typing import List
  19. from typing import Optional
  20. from typing import Sequence
  21. from typing import Tuple
  22. from typing import Union
  23. import humanfriendly
  24. import numpy as np
  25. import torch
  26. import torch.distributed as dist
  27. import torch.multiprocessing
  28. import torch.nn
  29. import torch.optim
  30. import yaml
  31. from funasr.models.base_model import FunASRModel
  32. from torch.utils.data import DataLoader
  33. from typeguard import check_argument_types
  34. from typeguard import check_return_type
  35. from funasr import __version__
  36. from funasr.datasets.dataset import AbsDataset
  37. from funasr.datasets.dataset import DATA_TYPES
  38. from funasr.datasets.dataset import ESPnetDataset
  39. from funasr.datasets.iterable_dataset import IterableESPnetDataset
  40. from funasr.iterators.abs_iter_factory import AbsIterFactory
  41. from funasr.iterators.chunk_iter_factory import ChunkIterFactory
  42. from funasr.iterators.multiple_iter_factory import MultipleIterFactory
  43. from funasr.iterators.sequence_iter_factory import SequenceIterFactory
  44. from funasr.main_funcs.collect_stats import collect_stats
  45. from funasr.optimizers.fairseq_adam import FairseqAdam
  46. from funasr.optimizers.sgd import SGD
  47. from funasr.samplers.build_batch_sampler import BATCH_TYPES
  48. from funasr.samplers.build_batch_sampler import build_batch_sampler
  49. from funasr.samplers.unsorted_batch_sampler import UnsortedBatchSampler
  50. from funasr.schedulers.noam_lr import NoamLR
  51. from funasr.schedulers.tri_stage_scheduler import TriStageLR
  52. from funasr.schedulers.warmup_lr import WarmupLR
  53. from funasr.torch_utils.load_pretrained_model import load_pretrained_model
  54. from funasr.torch_utils.model_summary import model_summary
  55. from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
  56. from funasr.torch_utils.set_all_random_seed import set_all_random_seed
  57. from funasr.train.class_choices import ClassChoices
  58. from funasr.train.distributed_utils import DistributedOption
  59. from funasr.train.trainer import Trainer
  60. from funasr.utils import config_argparse
  61. from funasr.utils.build_dataclass import build_dataclass
  62. from funasr.utils.cli_utils import get_commandline_args
  63. from funasr.utils.get_default_kwargs import get_default_kwargs
  64. from funasr.utils.nested_dict_action import NestedDictAction
  65. from funasr.utils.types import humanfriendly_parse_size_or_none
  66. from funasr.utils.types import int_or_none
  67. from funasr.utils.types import str2bool
  68. from funasr.utils.types import str2triple_str
  69. from funasr.utils.types import str_or_int
  70. from funasr.utils.types import str_or_none
  71. from funasr.utils.wav_utils import calc_shape, generate_data_list, filter_wav_text
  72. from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump
  73. try:
  74. import wandb
  75. except Exception:
  76. wandb = None
  77. if LooseVersion(torch.__version__) >= LooseVersion("1.5.0"):
  78. pass
  79. else:
  80. pass
  81. optim_classes = dict(
  82. adam=torch.optim.Adam,
  83. fairseq_adam=FairseqAdam,
  84. adamw=torch.optim.AdamW,
  85. sgd=SGD,
  86. adadelta=torch.optim.Adadelta,
  87. adagrad=torch.optim.Adagrad,
  88. adamax=torch.optim.Adamax,
  89. asgd=torch.optim.ASGD,
  90. lbfgs=torch.optim.LBFGS,
  91. rmsprop=torch.optim.RMSprop,
  92. rprop=torch.optim.Rprop,
  93. )
  94. if LooseVersion(torch.__version__) >= LooseVersion("1.10.0"):
  95. # From 1.10.0, RAdam is officially supported
  96. optim_classes.update(
  97. radam=torch.optim.RAdam,
  98. )
  99. try:
  100. import torch_optimizer
  101. optim_classes.update(
  102. accagd=torch_optimizer.AccSGD,
  103. adabound=torch_optimizer.AdaBound,
  104. adamod=torch_optimizer.AdaMod,
  105. diffgrad=torch_optimizer.DiffGrad,
  106. lamb=torch_optimizer.Lamb,
  107. novograd=torch_optimizer.NovoGrad,
  108. pid=torch_optimizer.PID,
  109. # torch_optimizer<=0.0.1a10 doesn't support
  110. # qhadam=torch_optimizer.QHAdam,
  111. qhm=torch_optimizer.QHM,
  112. sgdw=torch_optimizer.SGDW,
  113. yogi=torch_optimizer.Yogi,
  114. )
  115. if LooseVersion(torch_optimizer.__version__) < LooseVersion("0.2.0"):
  116. # From 0.2.0, RAdam is dropped
  117. optim_classes.update(
  118. radam=torch_optimizer.RAdam,
  119. )
  120. del torch_optimizer
  121. except ImportError:
  122. pass
  123. try:
  124. import apex
  125. optim_classes.update(
  126. fusedadam=apex.optimizers.FusedAdam,
  127. fusedlamb=apex.optimizers.FusedLAMB,
  128. fusednovograd=apex.optimizers.FusedNovoGrad,
  129. fusedsgd=apex.optimizers.FusedSGD,
  130. )
  131. del apex
  132. except ImportError:
  133. pass
  134. try:
  135. import fairscale
  136. except ImportError:
  137. fairscale = None
  138. scheduler_classes = dict(
  139. ReduceLROnPlateau=torch.optim.lr_scheduler.ReduceLROnPlateau,
  140. lambdalr=torch.optim.lr_scheduler.LambdaLR,
  141. steplr=torch.optim.lr_scheduler.StepLR,
  142. multisteplr=torch.optim.lr_scheduler.MultiStepLR,
  143. exponentiallr=torch.optim.lr_scheduler.ExponentialLR,
  144. CosineAnnealingLR=torch.optim.lr_scheduler.CosineAnnealingLR,
  145. noamlr=NoamLR,
  146. warmuplr=WarmupLR,
  147. tri_stage=TriStageLR,
  148. cycliclr=torch.optim.lr_scheduler.CyclicLR,
  149. onecyclelr=torch.optim.lr_scheduler.OneCycleLR,
  150. CosineAnnealingWarmRestarts=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
  151. )
  152. # To lower keys
  153. optim_classes = {k.lower(): v for k, v in optim_classes.items()}
  154. scheduler_classes = {k.lower(): v for k, v in scheduler_classes.items()}
@dataclass
class IteratorOptions:
    """Option bundle consumed by the task's iterator-factory builders.

    Groups everything needed to build a batch iterator — data sources,
    batching policy, caching limits and distributed flags — into a single
    object instead of a long argument list.
    """

    # Per-sample transform applied before collation (None disables it).
    preprocess_fn: Callable
    # Merges a list of samples into one mini-batch (given to DataLoader).
    collate_fn: Callable
    # Triples of (file_path, batch_key_name, file_type) describing inputs.
    data_path_and_name_and_type: list
    # Shape/length files used by length-aware batch samplers.
    shape_files: list
    batch_size: int
    batch_bins: int
    batch_type: str
    # In-memory cache budget — presumably bytes (humanfriendly-parsed); confirm.
    max_cache_size: float
    # Maximum number of cached open file descriptors.
    max_cache_fd: int
    distributed: bool
    # If not None, restrict iteration to this many batches.
    num_batches: Optional[int]
    # If not None, cap the number of iterations per epoch.
    num_iters_per_epoch: Optional[int]
    # True for the training iterator, False for validation/others.
    train: bool
  170. class AbsTask(ABC):
  171. # Use @staticmethod, or @classmethod,
  172. # instead of instance method to avoid God classes
  173. # If you need more than one optimizers, change this value in inheritance
  174. num_optimizers: int = 1
  175. trainer = Trainer
  176. class_choices_list: List[ClassChoices] = []
  177. finetune_args: None
  178. def __init__(self):
  179. raise RuntimeError("This class can't be instantiated.")
    @classmethod
    @abstractmethod
    def add_task_arguments(cls, parser: argparse.ArgumentParser):
        """Add task-specific command-line arguments to *parser*.

        Subclasses extend the common parser built by get_parser() with
        their own option groups; this base hook intentionally adds nothing.
        """
        pass
    @classmethod
    @abstractmethod
    def build_collate_fn(
        cls, args: argparse.Namespace, train: bool
    ) -> Callable[[Sequence[Dict[str, np.ndarray]]], Dict[str, torch.Tensor]]:
        """Return "collate_fn", which is a callable object and given to DataLoader.

        >>> from torch.utils.data import DataLoader
        >>> loader = DataLoader(collate_fn=cls.build_collate_fn(args, train=True), ...)

        In many cases, you can use our common collate_fn.
        """
        raise NotImplementedError
    @classmethod
    @abstractmethod
    def build_preprocess_fn(
        cls, args: argparse.Namespace, train: bool
    ) -> Optional[Callable[[str, Dict[str, np.ndarray]], Dict[str, np.ndarray]]]:
        """Return a per-sample preprocessing function, or None to disable it.

        NOTE(review): from the annotation, the callable appears to take
        (utterance_id, sample_dict) and return a transformed sample_dict —
        confirm against the dataset implementation.
        """
        raise NotImplementedError
    @classmethod
    @abstractmethod
    def required_data_names(
        cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        """Define the required names by Task.

        This function is used by
        >>> cls.check_task_requirements()

        If your model is defined as following,

        >>> from funasr.models.base_model import FunASRModel
        >>> class Model(FunASRModel):
        ...     def forward(self, input, output, opt=None): pass

        then "required_data_names" should be as

        >>> required_data_names = ('input', 'output')
        """
        raise NotImplementedError
    @classmethod
    @abstractmethod
    def optional_data_names(
        cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        """Define the optional names by Task.

        This function is used by
        >>> cls.check_task_requirements()

        If your model is defined as follows,

        >>> from funasr.models.base_model import FunASRModel
        >>> class Model(FunASRModel):
        ...     def forward(self, input, output, opt=None): pass

        then "optional_data_names" should be as

        >>> optional_data_names = ('opt',)
        """
        raise NotImplementedError
    @classmethod
    @abstractmethod
    def build_model(cls, args: argparse.Namespace) -> FunASRModel:
        """Construct and return the task's model from the parsed arguments."""
        raise NotImplementedError
  237. @classmethod
  238. def get_parser(cls) -> config_argparse.ArgumentParser:
  239. assert check_argument_types()
  240. class ArgumentDefaultsRawTextHelpFormatter(
  241. argparse.RawTextHelpFormatter,
  242. argparse.ArgumentDefaultsHelpFormatter,
  243. ):
  244. pass
  245. parser = config_argparse.ArgumentParser(
  246. description="base parser",
  247. formatter_class=ArgumentDefaultsRawTextHelpFormatter,
  248. )
  249. # NOTE(kamo): Use '_' instead of '-' to avoid confusion.
  250. # I think '-' looks really confusing if it's written in yaml.
  251. # NOTE(kamo): add_arguments(..., required=True) can't be used
  252. # to provide --print_config mode. Instead of it, do as
  253. # parser.set_defaults(required=["output_dir"])
  254. group = parser.add_argument_group("Common configuration")
  255. group.add_argument(
  256. "--print_config",
  257. action="store_true",
  258. help="Print the config file and exit",
  259. )
  260. group.add_argument(
  261. "--log_level",
  262. type=lambda x: x.upper(),
  263. default="INFO",
  264. choices=("ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
  265. help="The verbose level of logging",
  266. )
  267. group.add_argument(
  268. "--dry_run",
  269. type=str2bool,
  270. default=False,
  271. help="Perform process without training",
  272. )
  273. group.add_argument(
  274. "--iterator_type",
  275. type=str,
  276. choices=["sequence", "chunk", "task", "none"],
  277. default="sequence",
  278. help="Specify iterator type",
  279. )
  280. group.add_argument("--output_dir", type=str_or_none, default=None)
  281. group.add_argument(
  282. "--ngpu",
  283. type=int,
  284. default=0,
  285. help="The number of gpus. 0 indicates CPU mode",
  286. )
  287. group.add_argument("--seed", type=int, default=0, help="Random seed")
  288. group.add_argument(
  289. "--num_workers",
  290. type=int,
  291. default=1,
  292. help="The number of workers used for DataLoader",
  293. )
  294. group.add_argument(
  295. "--num_att_plot",
  296. type=int,
  297. default=3,
  298. help="The number images to plot the outputs from attention. "
  299. "This option makes sense only when attention-based model. "
  300. "We can also disable the attention plot by setting it 0",
  301. )
  302. group = parser.add_argument_group("distributed training related")
  303. group.add_argument(
  304. "--dist_backend",
  305. default="nccl",
  306. type=str,
  307. help="distributed backend",
  308. )
  309. group.add_argument(
  310. "--dist_init_method",
  311. type=str,
  312. default="env://",
  313. help='if init_method="env://", env values of "MASTER_PORT", "MASTER_ADDR", '
  314. '"WORLD_SIZE", and "RANK" are referred.',
  315. )
  316. group.add_argument(
  317. "--dist_world_size",
  318. default=None,
  319. type=int_or_none,
  320. help="number of nodes for distributed training",
  321. )
  322. group.add_argument(
  323. "--dist_rank",
  324. type=int_or_none,
  325. default=None,
  326. help="node rank for distributed training",
  327. )
  328. group.add_argument(
  329. # Not starting with "dist_" for compatibility to launch.py
  330. "--local_rank",
  331. type=int_or_none,
  332. default=None,
  333. help="local rank for distributed training. This option is used if "
  334. "--multiprocessing_distributed=false",
  335. )
  336. group.add_argument(
  337. "--dist_master_addr",
  338. default=None,
  339. type=str_or_none,
  340. help="The master address for distributed training. "
  341. "This value is used when dist_init_method == 'env://'",
  342. )
  343. group.add_argument(
  344. "--dist_master_port",
  345. default=None,
  346. type=int_or_none,
  347. help="The master port for distributed training"
  348. "This value is used when dist_init_method == 'env://'",
  349. )
  350. group.add_argument(
  351. "--dist_launcher",
  352. default=None,
  353. type=str_or_none,
  354. choices=["slurm", "mpi", None],
  355. help="The launcher type for distributed training",
  356. )
  357. group.add_argument(
  358. "--multiprocessing_distributed",
  359. default=False,
  360. type=str2bool,
  361. help="Use multi-processing distributed training to launch "
  362. "N processes per node, which has N GPUs. This is the "
  363. "fastest way to use PyTorch for either single node or "
  364. "multi node data parallel training",
  365. )
  366. group.add_argument(
  367. "--unused_parameters",
  368. type=str2bool,
  369. default=False,
  370. help="Whether to use the find_unused_parameters in "
  371. "torch.nn.parallel.DistributedDataParallel ",
  372. )
  373. group.add_argument(
  374. "--sharded_ddp",
  375. default=False,
  376. type=str2bool,
  377. help="Enable sharded training provided by fairscale",
  378. )
  379. group = parser.add_argument_group("cudnn mode related")
  380. group.add_argument(
  381. "--cudnn_enabled",
  382. type=str2bool,
  383. default=torch.backends.cudnn.enabled,
  384. help="Enable CUDNN",
  385. )
  386. group.add_argument(
  387. "--cudnn_benchmark",
  388. type=str2bool,
  389. default=torch.backends.cudnn.benchmark,
  390. help="Enable cudnn-benchmark mode",
  391. )
  392. group.add_argument(
  393. "--cudnn_deterministic",
  394. type=str2bool,
  395. default=True,
  396. help="Enable cudnn-deterministic mode",
  397. )
  398. group = parser.add_argument_group("collect stats mode related")
  399. group.add_argument(
  400. "--collect_stats",
  401. type=str2bool,
  402. default=False,
  403. help='Perform on "collect stats" mode',
  404. )
  405. group.add_argument(
  406. "--mc",
  407. type=bool,
  408. default=False,
  409. help="MultiChannel input",
  410. )
  411. group.add_argument(
  412. "--write_collected_feats",
  413. type=str2bool,
  414. default=False,
  415. help='Write the output features from the model when "collect stats" mode',
  416. )
  417. group = parser.add_argument_group("Trainer related")
  418. group.add_argument(
  419. "--max_epoch",
  420. type=int,
  421. default=40,
  422. help="The maximum number epoch to train",
  423. )
  424. group.add_argument(
  425. "--max_update",
  426. type=int,
  427. default=sys.maxsize,
  428. help="The maximum number update step to train",
  429. )
  430. parser.add_argument(
  431. "--batch_interval",
  432. type=int,
  433. default=-1,
  434. help="The batch interval for saving model.",
  435. )
  436. group.add_argument(
  437. "--patience",
  438. type=int_or_none,
  439. default=None,
  440. help="Number of epochs to wait without improvement "
  441. "before stopping the training",
  442. )
  443. group.add_argument(
  444. "--val_scheduler_criterion",
  445. type=str,
  446. nargs=2,
  447. default=("valid", "loss"),
  448. help="The criterion used for the value given to the lr scheduler. "
  449. 'Give a pair referring the phase, "train" or "valid",'
  450. 'and the criterion name. The mode specifying "min" or "max" can '
  451. "be changed by --scheduler_conf",
  452. )
  453. group.add_argument(
  454. "--early_stopping_criterion",
  455. type=str,
  456. nargs=3,
  457. default=("valid", "loss", "min"),
  458. help="The criterion used for judging of early stopping. "
  459. 'Give a pair referring the phase, "train" or "valid",'
  460. 'the criterion name and the mode, "min" or "max", e.g. "acc,max".',
  461. )
  462. group.add_argument(
  463. "--best_model_criterion",
  464. type=str2triple_str,
  465. nargs="+",
  466. default=[
  467. ("train", "loss", "min"),
  468. ("valid", "loss", "min"),
  469. ("train", "acc", "max"),
  470. ("valid", "acc", "max"),
  471. ],
  472. help="The criterion used for judging of the best model. "
  473. 'Give a pair referring the phase, "train" or "valid",'
  474. 'the criterion name, and the mode, "min" or "max", e.g. "acc,max".',
  475. )
  476. group.add_argument(
  477. "--keep_nbest_models",
  478. type=int,
  479. nargs="+",
  480. default=[10],
  481. help="Remove previous snapshots excluding the n-best scored epochs",
  482. )
  483. group.add_argument(
  484. "--nbest_averaging_interval",
  485. type=int,
  486. default=0,
  487. help="The epoch interval to apply model averaging and save nbest models",
  488. )
  489. group.add_argument(
  490. "--grad_clip",
  491. type=float,
  492. default=5.0,
  493. help="Gradient norm threshold to clip",
  494. )
  495. group.add_argument(
  496. "--grad_clip_type",
  497. type=float,
  498. default=2.0,
  499. help="The type of the used p-norm for gradient clip. Can be inf",
  500. )
  501. group.add_argument(
  502. "--grad_noise",
  503. type=str2bool,
  504. default=False,
  505. help="The flag to switch to use noise injection to "
  506. "gradients during training",
  507. )
  508. group.add_argument(
  509. "--accum_grad",
  510. type=int,
  511. default=1,
  512. help="The number of gradient accumulation",
  513. )
  514. group.add_argument(
  515. "--bias_grad_times",
  516. type=float,
  517. default=1.0,
  518. help="To scale the gradient of contextual related params",
  519. )
  520. group.add_argument(
  521. "--no_forward_run",
  522. type=str2bool,
  523. default=False,
  524. help="Just only iterating data loading without "
  525. "model forwarding and training",
  526. )
  527. group.add_argument(
  528. "--resume",
  529. type=str2bool,
  530. default=False,
  531. help="Enable resuming if checkpoint is existing",
  532. )
  533. group.add_argument(
  534. "--train_dtype",
  535. default="float32",
  536. choices=["float16", "float32", "float64"],
  537. help="Data type for training.",
  538. )
  539. group.add_argument(
  540. "--use_amp",
  541. type=str2bool,
  542. default=False,
  543. help="Enable Automatic Mixed Precision. This feature requires pytorch>=1.6",
  544. )
  545. group.add_argument(
  546. "--log_interval",
  547. type=int_or_none,
  548. default=None,
  549. help="Show the logs every the number iterations in each epochs at the "
  550. "training phase. If None is given, it is decided according the number "
  551. "of training samples automatically .",
  552. )
  553. group.add_argument(
  554. "--use_tensorboard",
  555. type=str2bool,
  556. default=True,
  557. help="Enable tensorboard logging",
  558. )
  559. group.add_argument(
  560. "--use_wandb",
  561. type=str2bool,
  562. default=False,
  563. help="Enable wandb logging",
  564. )
  565. group.add_argument(
  566. "--wandb_project",
  567. type=str,
  568. default=None,
  569. help="Specify wandb project",
  570. )
  571. group.add_argument(
  572. "--wandb_id",
  573. type=str,
  574. default=None,
  575. help="Specify wandb id",
  576. )
  577. group.add_argument(
  578. "--wandb_entity",
  579. type=str,
  580. default=None,
  581. help="Specify wandb entity",
  582. )
  583. group.add_argument(
  584. "--wandb_name",
  585. type=str,
  586. default=None,
  587. help="Specify wandb run name",
  588. )
  589. group.add_argument(
  590. "--wandb_model_log_interval",
  591. type=int,
  592. default=-1,
  593. help="Set the model log period",
  594. )
  595. group.add_argument(
  596. "--detect_anomaly",
  597. type=str2bool,
  598. default=False,
  599. help="Set torch.autograd.set_detect_anomaly",
  600. )
  601. group = parser.add_argument_group("Pretraining model related")
  602. group.add_argument("--pretrain_path", help="This option is obsoleted")
  603. group.add_argument(
  604. "--init_param",
  605. type=str,
  606. action="append",
  607. default=[],
  608. help="Specify the file path used for initialization of parameters. "
  609. "The format is '<file_path>:<src_key>:<dst_key>:<exclude_keys>', "
  610. "where file_path is the model file path, "
  611. "src_key specifies the key of model states to be used in the model file, "
  612. "dst_key specifies the attribute of the model to be initialized, "
  613. "and exclude_keys excludes keys of model states for the initialization."
  614. "e.g.\n"
  615. " # Load all parameters"
  616. " --init_param some/where/model.pb\n"
  617. " # Load only decoder parameters"
  618. " --init_param some/where/model.pb:decoder:decoder\n"
  619. " # Load only decoder parameters excluding decoder.embed"
  620. " --init_param some/where/model.pb:decoder:decoder:decoder.embed\n"
  621. " --init_param some/where/model.pb:decoder:decoder:decoder.embed\n",
  622. )
  623. group.add_argument(
  624. "--ignore_init_mismatch",
  625. type=str2bool,
  626. default=False,
  627. help="Ignore size mismatch when loading pre-trained model",
  628. )
  629. group.add_argument(
  630. "--freeze_param",
  631. type=str,
  632. default=[],
  633. action="append",
  634. help="Freeze parameters",
  635. )
  636. group = parser.add_argument_group("BatchSampler related")
  637. group.add_argument(
  638. "--num_iters_per_epoch",
  639. type=int_or_none,
  640. default=None,
  641. help="Restrict the number of iterations for training per epoch",
  642. )
  643. group.add_argument(
  644. "--batch_size",
  645. type=int,
  646. default=20,
  647. help="The mini-batch size used for training. Used if batch_type='unsorted',"
  648. " 'sorted', or 'folded'.",
  649. )
  650. group.add_argument(
  651. "--valid_batch_size",
  652. type=int_or_none,
  653. default=None,
  654. help="If not given, the value of --batch_size is used",
  655. )
  656. group.add_argument(
  657. "--batch_bins",
  658. type=int,
  659. default=1000000,
  660. help="The number of batch bins. Used if batch_type='length' or 'numel'",
  661. )
  662. group.add_argument(
  663. "--valid_batch_bins",
  664. type=int_or_none,
  665. default=None,
  666. help="If not given, the value of --batch_bins is used",
  667. )
  668. group.add_argument("--train_shape_file", type=str, action="append", default=[])
  669. group.add_argument("--valid_shape_file", type=str, action="append", default=[])
  670. group = parser.add_argument_group("Sequence iterator related")
  671. _batch_type_help = ""
  672. for key, value in BATCH_TYPES.items():
  673. _batch_type_help += f'"{key}":\n{value}\n'
  674. group.add_argument(
  675. "--batch_type",
  676. type=str,
  677. default="length",
  678. choices=list(BATCH_TYPES),
  679. help=_batch_type_help,
  680. )
  681. group.add_argument(
  682. "--valid_batch_type",
  683. type=str_or_none,
  684. default=None,
  685. choices=list(BATCH_TYPES) + [None],
  686. help="If not given, the value of --batch_type is used",
  687. )
  688. group.add_argument(
  689. "--speech_length_min",
  690. type=int,
  691. default=-1,
  692. help="speech length min",
  693. )
  694. group.add_argument(
  695. "--speech_length_max",
  696. type=int,
  697. default=-1,
  698. help="speech length max",
  699. )
  700. group.add_argument("--fold_length", type=int, action="append", default=[])
  701. group.add_argument(
  702. "--sort_in_batch",
  703. type=str,
  704. default="descending",
  705. choices=["descending", "ascending"],
  706. help="Sort the samples in each mini-batches by the sample "
  707. 'lengths. To enable this, "shape_file" must have the length information.',
  708. )
  709. group.add_argument(
  710. "--sort_batch",
  711. type=str,
  712. default="descending",
  713. choices=["descending", "ascending"],
  714. help="Sort mini-batches by the sample lengths",
  715. )
  716. group.add_argument(
  717. "--multiple_iterator",
  718. type=str2bool,
  719. default=False,
  720. help="Use multiple iterator mode",
  721. )
  722. group = parser.add_argument_group("Chunk iterator related")
  723. group.add_argument(
  724. "--chunk_length",
  725. type=str_or_int,
  726. default=500,
  727. help="Specify chunk length. e.g. '300', '300,400,500', or '300-400'."
  728. "If multiple numbers separated by command are given, "
  729. "one of them is selected randomly for each samples. "
  730. "If two numbers are given with '-', it indicates the range of the choices. "
  731. "Note that if the sequence length is shorter than the all chunk_lengths, "
  732. "the sample is discarded. ",
  733. )
  734. group.add_argument(
  735. "--chunk_shift_ratio",
  736. type=float,
  737. default=0.5,
  738. help="Specify the shift width of chunks. If it's less than 1, "
  739. "allows the overlapping and if bigger than 1, there are some gaps "
  740. "between each chunk.",
  741. )
  742. group.add_argument(
  743. "--num_cache_chunks",
  744. type=int,
  745. default=1024,
  746. help="Shuffle in the specified number of chunks and generate mini-batches "
  747. "More larger this value, more randomness can be obtained.",
  748. )
  749. group = parser.add_argument_group("Dataset related")
  750. _data_path_and_name_and_type_help = (
  751. "Give three words splitted by comma. It's used for the training data. "
  752. "e.g. '--train_data_path_and_name_and_type some/path/a.scp,foo,sound'. "
  753. "The first value, some/path/a.scp, indicates the file path, "
  754. "and the second, foo, is the key name used for the mini-batch data, "
  755. "and the last, sound, decides the file type. "
  756. "This option is repeatable, so you can input any number of features "
  757. "for your task. Supported file types are as follows:\n\n"
  758. )
  759. for key, dic in DATA_TYPES.items():
  760. _data_path_and_name_and_type_help += f'"{key}":\n{dic["help"]}\n\n'
  761. # for large dataset
  762. group.add_argument(
  763. "--dataset_type",
  764. type=str,
  765. default="small",
  766. help="whether to use dataloader for large dataset",
  767. )
  768. parser.add_argument(
  769. "--dataset_conf",
  770. action=NestedDictAction,
  771. default=dict(),
  772. help=f"The keyword arguments for dataset",
  773. )
  774. group.add_argument(
  775. "--train_data_file",
  776. type=str,
  777. default=None,
  778. help="train_list for large dataset",
  779. )
  780. group.add_argument(
  781. "--valid_data_file",
  782. type=str,
  783. default=None,
  784. help="valid_list for large dataset",
  785. )
  786. group.add_argument(
  787. "--train_data_path_and_name_and_type",
  788. type=str2triple_str,
  789. action="append",
  790. default=[],
  791. help=_data_path_and_name_and_type_help,
  792. )
  793. group.add_argument(
  794. "--valid_data_path_and_name_and_type",
  795. type=str2triple_str,
  796. action="append",
  797. default=[],
  798. )
  799. group.add_argument(
  800. "--allow_variable_data_keys",
  801. type=str2bool,
  802. default=False,
  803. help="Allow the arbitrary keys for mini-batch with ignoring "
  804. "the task requirements",
  805. )
  806. group.add_argument(
  807. "--max_cache_size",
  808. type=humanfriendly.parse_size,
  809. default=0.0,
  810. help="The maximum cache size for data loader. e.g. 10MB, 20GB.",
  811. )
  812. group.add_argument(
  813. "--max_cache_fd",
  814. type=int,
  815. default=32,
  816. help="The maximum number of file descriptors to be kept "
  817. "as opened for ark files. "
  818. "This feature is only valid when data type is 'kaldi_ark'.",
  819. )
  820. group.add_argument(
  821. "--valid_max_cache_size",
  822. type=humanfriendly_parse_size_or_none,
  823. default=None,
  824. help="The maximum cache size for validation data loader. e.g. 10MB, 20GB. "
  825. "If None, the 5 percent size of --max_cache_size",
  826. )
  827. group = parser.add_argument_group("Optimizer related")
  828. for i in range(1, cls.num_optimizers + 1):
  829. suf = "" if i == 1 else str(i)
  830. group.add_argument(
  831. f"--optim{suf}",
  832. type=lambda x: x.lower(),
  833. default="adadelta",
  834. choices=list(optim_classes),
  835. help="The optimizer type",
  836. )
  837. group.add_argument(
  838. f"--optim{suf}_conf",
  839. action=NestedDictAction,
  840. default=dict(),
  841. help="The keyword arguments for optimizer",
  842. )
  843. group.add_argument(
  844. f"--scheduler{suf}",
  845. type=lambda x: str_or_none(x.lower()),
  846. default=None,
  847. choices=list(scheduler_classes) + [None],
  848. help="The lr scheduler type",
  849. )
  850. group.add_argument(
  851. f"--scheduler{suf}_conf",
  852. action=NestedDictAction,
  853. default=dict(),
  854. help="The keyword arguments for lr scheduler",
  855. )
  856. # for training on PAI
  857. group = parser.add_argument_group("PAI training related")
  858. group.add_argument(
  859. "--use_pai",
  860. type=str2bool,
  861. default=False,
  862. help="flag to indicate whether training on PAI",
  863. )
  864. group.add_argument(
  865. "--simple_ddp",
  866. type=str2bool,
  867. default=False,
  868. )
  869. group.add_argument(
  870. "--num_worker_count",
  871. type=int,
  872. default=1,
  873. help="The number of machines on PAI.",
  874. )
  875. group.add_argument(
  876. "--access_key_id",
  877. type=str,
  878. default=None,
  879. help="The username for oss.",
  880. )
  881. group.add_argument(
  882. "--access_key_secret",
  883. type=str,
  884. default=None,
  885. help="The password for oss.",
  886. )
  887. group.add_argument(
  888. "--endpoint",
  889. type=str,
  890. default=None,
  891. help="The endpoint for oss.",
  892. )
  893. group.add_argument(
  894. "--bucket_name",
  895. type=str,
  896. default=None,
  897. help="The bucket name for oss.",
  898. )
  899. group.add_argument(
  900. "--oss_bucket",
  901. default=None,
  902. help="oss bucket.",
  903. )
  904. cls.trainer.add_arguments(parser)
  905. cls.add_task_arguments(parser)
  906. assert check_return_type(parser)
  907. return parser
  908. @classmethod
  909. def build_optimizers(
  910. cls,
  911. args: argparse.Namespace,
  912. model: torch.nn.Module,
  913. ) -> List[torch.optim.Optimizer]:
  914. if cls.num_optimizers != 1:
  915. raise RuntimeError(
  916. "build_optimizers() must be overridden if num_optimizers != 1"
  917. )
  918. optim_class = optim_classes.get(args.optim)
  919. if optim_class is None:
  920. raise ValueError(f"must be one of {list(optim_classes)}: {args.optim}")
  921. if args.sharded_ddp:
  922. if fairscale is None:
  923. raise RuntimeError("Requiring fairscale. Do 'pip install fairscale'")
  924. optim = fairscale.optim.oss.OSS(
  925. params=model.parameters(), optim=optim_class, **args.optim_conf
  926. )
  927. else:
  928. optim = optim_class(model.parameters(), **args.optim_conf)
  929. optimizers = [optim]
  930. return optimizers
  931. @classmethod
  932. def exclude_opts(cls) -> Tuple[str, ...]:
  933. """The options not to be shown by --print_config"""
  934. return "required", "print_config", "config", "ngpu"
  935. @classmethod
  936. def get_default_config(cls) -> Dict[str, Any]:
  937. """Return the configuration as dict.
  938. This method is used by print_config()
  939. """
  940. def get_class_type(name: str, classes: dict):
  941. _cls = classes.get(name)
  942. if _cls is None:
  943. raise ValueError(f"must be one of {list(classes)}: {name}")
  944. return _cls
  945. # This method is used only for --print_config
  946. assert check_argument_types()
  947. parser = cls.get_parser()
  948. args, _ = parser.parse_known_args()
  949. config = vars(args)
  950. # Excludes the options not to be shown
  951. for k in AbsTask.exclude_opts():
  952. config.pop(k)
  953. for i in range(1, cls.num_optimizers + 1):
  954. suf = "" if i == 1 else str(i)
  955. name = config[f"optim{suf}"]
  956. optim_class = get_class_type(name, optim_classes)
  957. conf = get_default_kwargs(optim_class)
  958. # Overwrite the default by the arguments,
  959. conf.update(config[f"optim{suf}_conf"])
  960. # and set it again
  961. config[f"optim{suf}_conf"] = conf
  962. name = config[f"scheduler{suf}"]
  963. if name is not None:
  964. scheduler_class = get_class_type(name, scheduler_classes)
  965. conf = get_default_kwargs(scheduler_class)
  966. # Overwrite the default by the arguments,
  967. conf.update(config[f"scheduler{suf}_conf"])
  968. # and set it again
  969. config[f"scheduler{suf}_conf"] = conf
  970. for class_choices in cls.class_choices_list:
  971. if getattr(args, class_choices.name) is not None:
  972. class_obj = class_choices.get_class(getattr(args, class_choices.name))
  973. conf = get_default_kwargs(class_obj)
  974. name = class_choices.name
  975. # Overwrite the default by the arguments,
  976. conf.update(config[f"{name}_conf"])
  977. # and set it again
  978. config[f"{name}_conf"] = conf
  979. return config
  980. @classmethod
  981. def check_required_command_args(cls, args: argparse.Namespace):
  982. assert check_argument_types()
  983. if hasattr(args, "required"):
  984. for k in vars(args):
  985. if "-" in k:
  986. raise RuntimeError(f'Use "_" instead of "-": parser.get_parser("{k}")')
  987. required = ", ".join(
  988. f"--{a}" for a in args.required if getattr(args, a) is None
  989. )
  990. if len(required) != 0:
  991. parser = cls.get_parser()
  992. parser.print_help(file=sys.stderr)
  993. p = Path(sys.argv[0]).name
  994. print(file=sys.stderr)
  995. print(
  996. f"{p}: error: the following arguments are required: " f"{required}",
  997. file=sys.stderr,
  998. )
  999. sys.exit(2)
  1000. @classmethod
  1001. def check_task_requirements(
  1002. cls,
  1003. dataset: Union[AbsDataset, IterableESPnetDataset],
  1004. allow_variable_data_keys: bool,
  1005. train: bool,
  1006. inference: bool = False,
  1007. ) -> None:
  1008. """Check if the dataset satisfy the requirement of current Task"""
  1009. assert check_argument_types()
  1010. mes = (
  1011. f"If you intend to use an additional input, modify "
  1012. f'"{cls.__name__}.required_data_names()" or '
  1013. f'"{cls.__name__}.optional_data_names()". '
  1014. f"Otherwise you need to set --allow_variable_data_keys true "
  1015. )
  1016. for k in cls.required_data_names(train, inference):
  1017. if not dataset.has_name(k):
  1018. raise RuntimeError(
  1019. f'"{cls.required_data_names(train, inference)}" are required for'
  1020. f' {cls.__name__}. but "{dataset.names()}" are input.\n{mes}'
  1021. )
  1022. if not allow_variable_data_keys:
  1023. task_keys = cls.required_data_names(
  1024. train, inference
  1025. ) + cls.optional_data_names(train, inference)
  1026. for k in dataset.names():
  1027. if k not in task_keys:
  1028. raise RuntimeError(
  1029. f"The data-name must be one of {task_keys} "
  1030. f'for {cls.__name__}: "{k}" is not allowed.\n{mes}'
  1031. )
  1032. @classmethod
  1033. def print_config(cls, file=sys.stdout) -> None:
  1034. assert check_argument_types()
  1035. # Shows the config: e.g. python train.py asr --print_config
  1036. config = cls.get_default_config()
  1037. file.write(yaml_no_alias_safe_dump(config, indent=4, sort_keys=False))
  1038. @classmethod
  1039. def main(cls, args: argparse.Namespace = None, cmd: Sequence[str] = None):
  1040. assert check_argument_types()
  1041. print(get_commandline_args(), file=sys.stderr)
  1042. if args is None:
  1043. parser = cls.get_parser()
  1044. args = parser.parse_args(cmd)
  1045. args.version = __version__
  1046. if args.pretrain_path is not None:
  1047. raise RuntimeError("--pretrain_path is deprecated. Use --init_param")
  1048. if args.print_config:
  1049. cls.print_config()
  1050. sys.exit(0)
  1051. cls.check_required_command_args(args)
  1052. if not args.distributed or not args.multiprocessing_distributed:
  1053. cls.main_worker(args)
  1054. else:
  1055. assert args.ngpu > 1
  1056. cls.main_worker(args)
  1057. @classmethod
  1058. def run(cls):
  1059. assert hasattr(cls, "finetune_args")
  1060. args = cls.finetune_args
  1061. args.train_shape_file = None
  1062. if args.distributed:
  1063. args.simple_ddp = True
  1064. else:
  1065. args.simple_ddp = False
  1066. args.ngpu = 1
  1067. args.use_pai = False
  1068. args.batch_type = "length"
  1069. args.oss_bucket = None
  1070. args.input_size = None
  1071. cls.main_worker(args)
    @classmethod
    def main_worker(cls, args: argparse.Namespace):
        """Run the full training pipeline for one worker process.

        Steps: distributed init, on-the-fly data preparation, logging setup,
        model/optimizer/scheduler construction, config dump, then either
        collect-stats mode or training via ``cls.trainer.run()``.
        """
        assert check_argument_types()

        # 0. Init distributed process
        distributed_option = build_dataclass(DistributedOption, args)
        # Setting distributed_option.dist_rank, etc.
        if args.use_pai:
            distributed_option.init_options_pai()
        elif not args.simple_ddp:
            distributed_option.init_options()

        # Invoking torch.distributed.init_process_group
        if args.use_pai:
            distributed_option.init_torch_distributed_pai(args)
        elif not args.simple_ddp:
            distributed_option.init_torch_distributed(args)
        elif args.distributed and args.simple_ddp:
            # Simple-DDP: initialize like PAI and derive ngpu from the world size.
            distributed_option.init_torch_distributed_pai(args)
            args.ngpu = dist.get_world_size()

        # Scale batch settings by the number of GPUs (small-dataset mode only).
        if args.dataset_type == "small" and args.ngpu > 0:
            if args.batch_size is not None:
                args.batch_size = args.batch_size * args.ngpu
            if args.batch_bins is not None and args.ngpu > 0:
                args.batch_bins = args.batch_bins * args.ngpu

        # filter samples if wav.scp and text are mismatch
        if (
            args.train_shape_file is None and args.dataset_type == "small"
        ) or args.train_data_file is None and args.dataset_type == "large":
            # Only rank 0 touches the filesystem; other ranks wait at the barrier.
            if not args.simple_ddp or distributed_option.dist_rank == 0:
                filter_wav_text(args.data_dir, args.train_set)
                filter_wav_text(args.data_dir, args.dev_set)
            if args.simple_ddp:
                dist.barrier()

        # Compute speech-shape files on the fly when not provided (small datasets).
        if args.train_shape_file is None and args.dataset_type == "small":
            if not args.simple_ddp or distributed_option.dist_rank == 0:
                calc_shape(args.data_dir, args.train_set, args.frontend_conf, args.speech_length_min,
                           args.speech_length_max)
                calc_shape(args.data_dir, args.dev_set, args.frontend_conf, args.speech_length_min,
                           args.speech_length_max)
            if args.simple_ddp:
                dist.barrier()
            args.train_shape_file = [os.path.join(args.data_dir, args.train_set, "speech_shape")]
            args.valid_shape_file = [os.path.join(args.data_dir, args.dev_set, "speech_shape")]

        # Generate data.list files on the fly when not provided (large datasets).
        if args.train_data_file is None and args.dataset_type == "large":
            if not args.simple_ddp or distributed_option.dist_rank == 0:
                generate_data_list(args.data_dir, args.train_set)
                generate_data_list(args.data_dir, args.dev_set)
            if args.simple_ddp:
                dist.barrier()
            args.train_data_file = os.path.join(args.data_dir, args.train_set, "data.list")
            args.valid_data_file = os.path.join(args.data_dir, args.dev_set, "data.list")

        # NOTE(kamo): Don't use logging before invoking logging.basicConfig()
        if not distributed_option.distributed or distributed_option.dist_rank == 0:
            if not distributed_option.distributed:
                _rank = ""
            else:
                _rank = (
                    f":{distributed_option.dist_rank}/"
                    f"{distributed_option.dist_world_size}"
                )
            # NOTE(kamo):
            # logging.basicConfig() is invoked in main_worker() instead of main()
            # because it can be invoked only once in a process.
            # FIXME(kamo): Should we use logging.getLogger()?
            # BUGFIX: Remove previous handlers and reset log level
            for handler in logging.root.handlers[:]:
                logging.root.removeHandler(handler)
            logging.basicConfig(
                level=args.log_level,
                format=f"[{os.uname()[1].split('.')[0]}]"
                f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
            )
        else:
            # BUGFIX: Remove previous handlers and reset log level
            for handler in logging.root.handlers[:]:
                logging.root.removeHandler(handler)
            # Suppress logging if RANK != 0
            logging.basicConfig(
                level="ERROR",
                format=f"[{os.uname()[1].split('.')[0]}]"
                f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
            )
        logging.info("world size: {}, rank: {}, local_rank: {}".format(distributed_option.dist_world_size,
                                                                       distributed_option.dist_rank,
                                                                       distributed_option.local_rank))

        # 1. Set random-seed
        set_all_random_seed(args.seed)
        torch.backends.cudnn.enabled = args.cudnn_enabled
        torch.backends.cudnn.benchmark = args.cudnn_benchmark
        torch.backends.cudnn.deterministic = args.cudnn_deterministic
        if args.detect_anomaly:
            logging.info("Invoking torch.autograd.set_detect_anomaly(True)")
            torch.autograd.set_detect_anomaly(args.detect_anomaly)

        # 2. Build model
        model = cls.build_model(args=args)
        if not isinstance(model, FunASRModel):
            raise RuntimeError(
                f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
            )
        model = model.to(
            dtype=getattr(torch, args.train_dtype),
            device="cuda" if args.ngpu > 0 else "cpu",
        )
        # Freeze every parameter whose name matches a --freeze_param prefix.
        for t in args.freeze_param:
            for k, p in model.named_parameters():
                if k.startswith(t + ".") or k == t:
                    logging.info(f"Setting {k}.requires_grad = False")
                    p.requires_grad = False

        # 3. Build optimizer
        optimizers = cls.build_optimizers(args, model=model)

        # 4. Build schedulers (one per optimizer; entry may be None)
        schedulers = []
        for i, optim in enumerate(optimizers, 1):
            suf = "" if i == 1 else str(i)
            name = getattr(args, f"scheduler{suf}")
            conf = getattr(args, f"scheduler{suf}_conf")
            if name is not None:
                cls_ = scheduler_classes.get(name)
                if cls_ is None:
                    raise ValueError(
                        f"must be one of {list(scheduler_classes)}: {name}"
                    )
                scheduler = cls_(optim, **conf)
            else:
                scheduler = None
            schedulers.append(scheduler)

        logging.info(pytorch_cudnn_version())
        logging.info(model_summary(model))
        for i, (o, s) in enumerate(zip(optimizers, schedulers), 1):
            suf = "" if i == 1 else str(i)
            logging.info(f"Optimizer{suf}:\n{o}")
            logging.info(f"Scheduler{suf}: {s}")

        # 5. Dump "args" to config.yaml
        # NOTE(kamo): "args" should be saved after object-buildings are done
        # because they are allowed to modify "args".
        output_dir = Path(args.output_dir)
        if not distributed_option.distributed or distributed_option.dist_rank == 0:
            output_dir.mkdir(parents=True, exist_ok=True)
            with (output_dir / "config.yaml").open("w", encoding="utf-8") as f:
                logging.info(
                    f'Saving the configuration in {output_dir / "config.yaml"}'
                )
                if args.use_pai:
                    # On PAI, persist the config to OSS instead of local disk.
                    buffer = BytesIO()
                    torch.save({"config": vars(args)}, buffer)
                    args.oss_bucket.put_object(os.path.join(args.output_dir, "config.dict"), buffer.getvalue())
                else:
                    yaml_no_alias_safe_dump(vars(args), f, indent=4, sort_keys=False)

        if args.dry_run:
            pass
        elif args.collect_stats:
            # Perform on collect_stats mode. This mode has two roles
            # - Derive the length and dimension of all input data
            # - Accumulate feats, square values, and the length for whitening
            if args.valid_batch_size is None:
                args.valid_batch_size = args.batch_size
            if len(args.train_shape_file) != 0:
                train_key_file = args.train_shape_file[0]
            else:
                train_key_file = None
            if len(args.valid_shape_file) != 0:
                valid_key_file = args.valid_shape_file[0]
            else:
                valid_key_file = None
            collect_stats(
                model=model,
                train_iter=cls.build_streaming_iterator(
                    data_path_and_name_and_type=args.train_data_path_and_name_and_type,
                    key_file=train_key_file,
                    batch_size=args.batch_size,
                    mc=args.mc,
                    dtype=args.train_dtype,
                    num_workers=args.num_workers,
                    allow_variable_data_keys=args.allow_variable_data_keys,
                    ngpu=args.ngpu,
                    preprocess_fn=cls.build_preprocess_fn(args, train=False),
                    collate_fn=cls.build_collate_fn(args, train=False),
                ),
                valid_iter=cls.build_streaming_iterator(
                    data_path_and_name_and_type=args.valid_data_path_and_name_and_type,
                    key_file=valid_key_file,
                    batch_size=args.valid_batch_size,
                    mc=args.mc,
                    dtype=args.train_dtype,
                    num_workers=args.num_workers,
                    allow_variable_data_keys=args.allow_variable_data_keys,
                    ngpu=args.ngpu,
                    preprocess_fn=cls.build_preprocess_fn(args, train=False),
                    collate_fn=cls.build_collate_fn(args, train=False),
                ),
                output_dir=output_dir,
                ngpu=args.ngpu,
                log_interval=args.log_interval,
                write_collected_feats=args.write_collected_feats,
            )
        else:
            logging.info("Training args: {}".format(args))

            # 6. Loads pre-trained model
            for p in args.init_param:
                logging.info(f"Loading pretrained params from {p}")
                load_pretrained_model(
                    model=model,
                    init_param=p,
                    ignore_init_mismatch=args.ignore_init_mismatch,
                    # NOTE(kamo): "cuda" for torch.load always indicates cuda:0
                    # in PyTorch<=1.4
                    map_location=f"cuda:{torch.cuda.current_device()}"
                    if args.ngpu > 0
                    else "cpu",
                    oss_bucket=args.oss_bucket,
                )

            # 7. Build iterator factories
            if args.dataset_type == "large":
                from funasr.datasets.large_datasets.build_dataloader import LargeDataLoader
                train_iter_factory = LargeDataLoader(args, mode="train")
                valid_iter_factory = LargeDataLoader(args, mode="eval")
            elif args.dataset_type == "small":
                train_iter_factory = cls.build_iter_factory(
                    args=args,
                    distributed_option=distributed_option,
                    mode="train",
                )
                valid_iter_factory = cls.build_iter_factory(
                    args=args,
                    distributed_option=distributed_option,
                    mode="valid",
                )
            else:
                raise ValueError(f"Not supported dataset_type={args.dataset_type}")

            # NOTE(review): args.max_update appears to be set inside
            # build_sequence_iter_factory for the tri_stage scheduler — confirm
            # it is defined before this point for the large-dataset path.
            if args.scheduler == "tri_stage":
                for scheduler in schedulers:
                    scheduler.init_tri_stage_scheudler(max_update=args.max_update)

            # 8. Start training
            if args.use_wandb:
                if wandb is None:
                    raise RuntimeError("Please install wandb")
                try:
                    wandb.login()
                except wandb.errors.UsageError:
                    logging.info("wandb not configured! run `wandb login` to enable")
                    args.use_wandb = False
            if args.use_wandb:
                if (
                    not distributed_option.distributed
                    or distributed_option.dist_rank == 0
                ):
                    if args.wandb_project is None:
                        project = "FunASR_" + cls.__name__
                    else:
                        project = args.wandb_project
                    if args.wandb_name is None:
                        name = str(Path(".").resolve()).replace("/", "_")
                    else:
                        name = args.wandb_name
                    wandb.init(
                        entity=args.wandb_entity,
                        project=project,
                        name=name,
                        dir=output_dir,
                        id=args.wandb_id,
                        resume="allow",
                    )
                    wandb.config.update(args)
                else:
                    # wandb also supports grouping for distributed training,
                    # but we only logs aggregated data,
                    # so it's enough to perform on rank0 node.
                    args.use_wandb = False

            # Don't give args to trainer.run() directly!!!
            # Instead of it, define "Options" object and build here.
            trainer_options = cls.trainer.build_options(args)
            cls.trainer.run(
                model=model,
                optimizers=optimizers,
                schedulers=schedulers,
                train_iter_factory=train_iter_factory,
                valid_iter_factory=valid_iter_factory,
                trainer_options=trainer_options,
                distributed_option=distributed_option,
            )
            if args.use_wandb and wandb.run:
                wandb.finish()
  1352. @classmethod
  1353. def build_iter_options(
  1354. cls,
  1355. args: argparse.Namespace,
  1356. distributed_option: DistributedOption,
  1357. mode: str,
  1358. ):
  1359. if mode == "train":
  1360. preprocess_fn = cls.build_preprocess_fn(args, train=True)
  1361. collate_fn = cls.build_collate_fn(args, train=True)
  1362. data_path_and_name_and_type = args.train_data_path_and_name_and_type
  1363. shape_files = args.train_shape_file
  1364. batch_size = args.batch_size
  1365. batch_bins = args.batch_bins
  1366. batch_type = args.batch_type
  1367. max_cache_size = args.max_cache_size
  1368. max_cache_fd = args.max_cache_fd
  1369. distributed = distributed_option.distributed
  1370. num_batches = None
  1371. num_iters_per_epoch = args.num_iters_per_epoch
  1372. train = True
  1373. elif mode == "valid":
  1374. preprocess_fn = cls.build_preprocess_fn(args, train=False)
  1375. collate_fn = cls.build_collate_fn(args, train=False)
  1376. data_path_and_name_and_type = args.valid_data_path_and_name_and_type
  1377. shape_files = args.valid_shape_file
  1378. if args.valid_batch_type is None:
  1379. batch_type = args.batch_type
  1380. else:
  1381. batch_type = args.valid_batch_type
  1382. if args.valid_batch_size is None:
  1383. batch_size = args.batch_size
  1384. else:
  1385. batch_size = args.valid_batch_size
  1386. if args.valid_batch_bins is None:
  1387. batch_bins = args.batch_bins
  1388. else:
  1389. batch_bins = args.valid_batch_bins
  1390. if args.valid_max_cache_size is None:
  1391. # Cache 5% of maximum size for validation loader
  1392. max_cache_size = 0.05 * args.max_cache_size
  1393. else:
  1394. max_cache_size = args.valid_max_cache_size
  1395. max_cache_fd = args.max_cache_fd
  1396. distributed = distributed_option.distributed
  1397. num_batches = None
  1398. num_iters_per_epoch = None
  1399. train = False
  1400. else:
  1401. raise NotImplementedError(f"mode={mode}")
  1402. return IteratorOptions(
  1403. preprocess_fn=preprocess_fn,
  1404. collate_fn=collate_fn,
  1405. data_path_and_name_and_type=data_path_and_name_and_type,
  1406. shape_files=shape_files,
  1407. batch_type=batch_type,
  1408. batch_size=batch_size,
  1409. batch_bins=batch_bins,
  1410. num_batches=num_batches,
  1411. max_cache_size=max_cache_size,
  1412. max_cache_fd=max_cache_fd,
  1413. distributed=distributed,
  1414. num_iters_per_epoch=num_iters_per_epoch,
  1415. train=train,
  1416. )
  1417. @classmethod
  1418. def build_iter_factory(
  1419. cls,
  1420. args: argparse.Namespace,
  1421. distributed_option: DistributedOption,
  1422. mode: str,
  1423. kwargs: dict = None,
  1424. ) -> AbsIterFactory:
  1425. """Build a factory object of mini-batch iterator.
  1426. This object is invoked at every epochs to build the iterator for each epoch
  1427. as following:
  1428. >>> iter_factory = cls.build_iter_factory(...)
  1429. >>> for epoch in range(1, max_epoch):
  1430. ... for keys, batch in iter_fatory.build_iter(epoch):
  1431. ... model(**batch)
  1432. The mini-batches for each epochs are fully controlled by this class.
  1433. Note that the random seed used for shuffling is decided as "seed + epoch" and
  1434. the generated mini-batches can be reproduces when resuming.
  1435. Note that the definition of "epoch" doesn't always indicate
  1436. to run out of the whole training corpus.
  1437. "--num_iters_per_epoch" option restricts the number of iterations for each epoch
  1438. and the rest of samples for the originally epoch are left for the next epoch.
  1439. e.g. If The number of mini-batches equals to 4, the following two are same:
  1440. - 1 epoch without "--num_iters_per_epoch"
  1441. - 4 epoch with "--num_iters_per_epoch" == 4
  1442. """
  1443. assert check_argument_types()
  1444. iter_options = cls.build_iter_options(args, distributed_option, mode)
  1445. # Overwrite iter_options if any kwargs is given
  1446. if kwargs is not None:
  1447. for k, v in kwargs.items():
  1448. setattr(iter_options, k, v)
  1449. if args.iterator_type == "sequence":
  1450. return cls.build_sequence_iter_factory(
  1451. args=args,
  1452. iter_options=iter_options,
  1453. mode=mode,
  1454. )
  1455. elif args.iterator_type == "chunk":
  1456. return cls.build_chunk_iter_factory(
  1457. args=args,
  1458. iter_options=iter_options,
  1459. mode=mode,
  1460. )
  1461. elif args.iterator_type == "task":
  1462. return cls.build_task_iter_factory(
  1463. args=args,
  1464. iter_options=iter_options,
  1465. mode=mode,
  1466. )
  1467. else:
  1468. raise RuntimeError(f"Not supported: iterator_type={args.iterator_type}")
    @classmethod
    def build_sequence_iter_factory(
        cls, args: argparse.Namespace, iter_options: IteratorOptions, mode: str
    ) -> AbsIterFactory:
        """Build a SequenceIterFactory for full-utterance (sequence) batching.

        ``iter_options`` carries the mode-specific settings assembled by
        build_iter_options(); ``mode`` is "train" or "valid".
        """
        assert check_argument_types()
        # Resolve the target sampling rate from the frontend config when
        # present; otherwise fall back to 16 kHz.
        if hasattr(args, "frontend_conf"):
            if args.frontend_conf is not None and "fs" in args.frontend_conf:
                dest_sample_rate = args.frontend_conf["fs"]
            else:
                dest_sample_rate = 16000
        else:
            dest_sample_rate = 16000
        dataset = ESPnetDataset(
            iter_options.data_path_and_name_and_type,
            float_dtype=args.train_dtype,
            preprocess=iter_options.preprocess_fn,
            max_cache_size=iter_options.max_cache_size,
            max_cache_fd=iter_options.max_cache_fd,
            dest_sample_rate=dest_sample_rate,
        )
        cls.check_task_requirements(
            dataset, args.allow_variable_data_keys, train=iter_options.train
        )
        # Optional utt2category file located next to the first data file.
        if Path(
            Path(iter_options.data_path_and_name_and_type[0][0]).parent, "utt2category"
        ).exists():
            utt2category_file = str(
                Path(
                    Path(iter_options.data_path_and_name_and_type[0][0]).parent,
                    "utt2category",
                )
            )
        else:
            utt2category_file = None
        batch_sampler = build_batch_sampler(
            type=iter_options.batch_type,
            shape_files=iter_options.shape_files,
            fold_lengths=args.fold_length,
            batch_size=iter_options.batch_size,
            batch_bins=iter_options.batch_bins,
            sort_in_batch=args.sort_in_batch,
            sort_batch=args.sort_batch,
            drop_last=False,
            # In distributed mode every rank must get at least one sample per
            # batch, so the minimum batch size equals the world size.
            min_batch_size=torch.distributed.get_world_size()
            if iter_options.distributed
            else 1,
            utt2category_file=utt2category_file,
        )

        batches = list(batch_sampler)
        if iter_options.num_batches is not None:
            batches = batches[: iter_options.num_batches]

        bs_list = [len(batch) for batch in batches]

        logging.info(f"[{mode}] dataset:\n{dataset}")
        logging.info(f"[{mode}] Batch sampler: {batch_sampler}")
        logging.info(
            f"[{mode}] mini-batch sizes summary: N-batch={len(bs_list)}, "
            f"mean={np.mean(bs_list):.1f}, min={np.min(bs_list)}, max={np.max(bs_list)}"
        )
        # NOTE: mutates args — the tri_stage scheduler needs the total number
        # of updates, which is only known once the batches are materialized.
        if args.scheduler == "tri_stage" and mode == "train":
            args.max_update = len(bs_list) * args.max_epoch
            logging.info("Max update: {}".format(args.max_update))

        if iter_options.distributed:
            world_size = torch.distributed.get_world_size()
            rank = torch.distributed.get_rank()
            for batch in batches:
                if len(batch) < world_size:
                    raise RuntimeError(
                        f"The batch-size must be equal or more than world_size: "
                        f"{len(batch)} < {world_size}"
                    )
            # Shard each batch across ranks by striding over its samples.
            batches = [batch[rank::world_size] for batch in batches]

        return SequenceIterFactory(
            dataset=dataset,
            batches=batches,
            seed=args.seed,
            num_iters_per_epoch=iter_options.num_iters_per_epoch,
            shuffle=iter_options.train,
            num_workers=args.num_workers,
            collate_fn=iter_options.collate_fn,
            pin_memory=args.ngpu > 0,
        )
  1550. @classmethod
  1551. def build_chunk_iter_factory(
  1552. cls,
  1553. args: argparse.Namespace,
  1554. iter_options: IteratorOptions,
  1555. mode: str,
  1556. ) -> AbsIterFactory:
  1557. assert check_argument_types()
  1558. dataset = ESPnetDataset(
  1559. iter_options.data_path_and_name_and_type,
  1560. float_dtype=args.train_dtype,
  1561. preprocess=iter_options.preprocess_fn,
  1562. max_cache_size=iter_options.max_cache_size,
  1563. max_cache_fd=iter_options.max_cache_fd,
  1564. )
  1565. cls.check_task_requirements(
  1566. dataset, args.allow_variable_data_keys, train=iter_options.train
  1567. )
  1568. if len(iter_options.shape_files) == 0:
  1569. key_file = iter_options.data_path_and_name_and_type[0][0]
  1570. else:
  1571. key_file = iter_options.shape_files[0]
  1572. batch_sampler = UnsortedBatchSampler(batch_size=1, key_file=key_file)
  1573. batches = list(batch_sampler)
  1574. if iter_options.num_batches is not None:
  1575. batches = batches[: iter_options.num_batches]
  1576. logging.info(f"[{mode}] dataset:\n{dataset}")
  1577. if iter_options.distributed:
  1578. world_size = torch.distributed.get_world_size()
  1579. rank = torch.distributed.get_rank()
  1580. if len(batches) < world_size:
  1581. raise RuntimeError("Number of samples is smaller than world_size")
  1582. if iter_options.batch_size < world_size:
  1583. raise RuntimeError("batch_size must be equal or more than world_size")
  1584. if rank < iter_options.batch_size % world_size:
  1585. batch_size = iter_options.batch_size // world_size + 1
  1586. else:
  1587. batch_size = iter_options.batch_size // world_size
  1588. num_cache_chunks = args.num_cache_chunks // world_size
  1589. # NOTE(kamo): Split whole corpus by sample numbers without considering
  1590. # each of the lengths, therefore the number of iteration counts are not
  1591. # always equal to each other and the iterations are limitted
  1592. # by the fewest iterations.
  1593. # i.e. the samples over the counts are discarded.
  1594. batches = batches[rank::world_size]
  1595. else:
  1596. batch_size = iter_options.batch_size
  1597. num_cache_chunks = args.num_cache_chunks
  1598. return ChunkIterFactory(
  1599. dataset=dataset,
  1600. batches=batches,
  1601. seed=args.seed,
  1602. batch_size=batch_size,
  1603. # For chunk iterator,
  1604. # --num_iters_per_epoch doesn't indicate the number of iterations,
  1605. # but indicates the number of samples.
  1606. num_samples_per_epoch=iter_options.num_iters_per_epoch,
  1607. shuffle=iter_options.train,
  1608. num_workers=args.num_workers,
  1609. collate_fn=iter_options.collate_fn,
  1610. pin_memory=args.ngpu > 0,
  1611. chunk_length=args.chunk_length,
  1612. chunk_shift_ratio=args.chunk_shift_ratio,
  1613. num_cache_chunks=num_cache_chunks,
  1614. )
    # NOTE(kamo): Not abstract class
    @classmethod
    def build_task_iter_factory(
        cls,
        args: argparse.Namespace,
        iter_options: IteratorOptions,
        mode: str,
    ) -> AbsIterFactory:
        """Build a task-specific iterator factory.

        Subclasses that select the "task" iterator type must override this
        method; the base implementation only raises ``NotImplementedError``.

        Args:
            args: Parsed command-line options.
            iter_options: Common iterator options built by ``build_iter_options``.
            mode: Iterator mode tag (e.g. "train" or "valid") — used elsewhere
                for logging and option selection.

        Example:
            >>> class YourTask(AbsTask):
            ...     @classmethod
            ...     def add_task_arguments(cls, parser: argparse.ArgumentParser):
            ...         parser.set_defaults(iterator_type="task")
            ...
            ...     @classmethod
            ...     def build_task_iter_factory(
            ...         cls,
            ...         args: argparse.Namespace,
            ...         iter_options: IteratorOptions,
            ...         mode: str,
            ...     ):
            ...         return FooIterFactory(...)
            ...
            ...     @classmethod
            ...     def build_iter_options(
            ...         cls,
            ...         args: argparse.Namespace,
            ...         distributed_option: DistributedOption,
            ...         mode: str,
            ...     ):
            ...         # if you need to customize options object
        """
        raise NotImplementedError
  1648. @classmethod
  1649. def build_multiple_iter_factory(
  1650. cls, args: argparse.Namespace, distributed_option: DistributedOption, mode: str
  1651. ):
  1652. assert check_argument_types()
  1653. iter_options = cls.build_iter_options(args, distributed_option, mode)
  1654. assert len(iter_options.data_path_and_name_and_type) > 0, len(
  1655. iter_options.data_path_and_name_and_type
  1656. )
  1657. # 1. Sanity check
  1658. num_splits = None
  1659. for path in [
  1660. path for path, _, _ in iter_options.data_path_and_name_and_type
  1661. ] + list(iter_options.shape_files):
  1662. if not Path(path).is_dir():
  1663. raise RuntimeError(f"{path} is not a directory")
  1664. p = Path(path) / "num_splits"
  1665. if not p.exists():
  1666. raise FileNotFoundError(f"{p} is not found")
  1667. with p.open() as f:
  1668. _num_splits = int(f.read())
  1669. if num_splits is not None and num_splits != _num_splits:
  1670. raise RuntimeError(
  1671. f"Number of splits are mismathed: "
  1672. f"{iter_options.data_path_and_name_and_type[0][0]} and {path}"
  1673. )
  1674. num_splits = _num_splits
  1675. for i in range(num_splits):
  1676. p = Path(path) / f"split.{i}"
  1677. if not p.exists():
  1678. raise FileNotFoundError(f"{p} is not found")
  1679. # 2. Create functions to build an iter factory for each splits
  1680. data_path_and_name_and_type_list = [
  1681. [
  1682. (str(Path(p) / f"split.{i}"), n, t)
  1683. for p, n, t in iter_options.data_path_and_name_and_type
  1684. ]
  1685. for i in range(num_splits)
  1686. ]
  1687. shape_files_list = [
  1688. [str(Path(s) / f"split.{i}") for s in iter_options.shape_files]
  1689. for i in range(num_splits)
  1690. ]
  1691. num_iters_per_epoch_list = [
  1692. (iter_options.num_iters_per_epoch + i) // num_splits
  1693. if iter_options.num_iters_per_epoch is not None
  1694. else None
  1695. for i in range(num_splits)
  1696. ]
  1697. max_cache_size = iter_options.max_cache_size / num_splits
  1698. # Note that iter-factories are built for each epoch at runtime lazily.
  1699. build_funcs = [
  1700. functools.partial(
  1701. cls.build_iter_factory,
  1702. args,
  1703. distributed_option,
  1704. mode,
  1705. kwargs=dict(
  1706. data_path_and_name_and_type=_data_path_and_name_and_type,
  1707. shape_files=_shape_files,
  1708. num_iters_per_epoch=_num_iters_per_epoch,
  1709. max_cache_size=max_cache_size,
  1710. ),
  1711. )
  1712. for (
  1713. _data_path_and_name_and_type,
  1714. _shape_files,
  1715. _num_iters_per_epoch,
  1716. ) in zip(
  1717. data_path_and_name_and_type_list,
  1718. shape_files_list,
  1719. num_iters_per_epoch_list,
  1720. )
  1721. ]
  1722. # 3. Build MultipleIterFactory
  1723. return MultipleIterFactory(
  1724. build_funcs=build_funcs, shuffle=iter_options.train, seed=args.seed
  1725. )
  1726. @classmethod
  1727. def build_streaming_iterator(
  1728. cls,
  1729. data_path_and_name_and_type,
  1730. preprocess_fn,
  1731. collate_fn,
  1732. key_file: str = None,
  1733. batch_size: int = 1,
  1734. fs: dict = None,
  1735. mc: bool = False,
  1736. dtype: str = np.float32,
  1737. num_workers: int = 1,
  1738. allow_variable_data_keys: bool = False,
  1739. ngpu: int = 0,
  1740. inference: bool = False,
  1741. ) -> DataLoader:
  1742. """Build DataLoader using iterable dataset"""
  1743. assert check_argument_types()
  1744. # For backward compatibility for pytorch DataLoader
  1745. if collate_fn is not None:
  1746. kwargs = dict(collate_fn=collate_fn)
  1747. else:
  1748. kwargs = {}
  1749. dataset = IterableESPnetDataset(
  1750. data_path_and_name_and_type,
  1751. float_dtype=dtype,
  1752. fs=fs,
  1753. mc=mc,
  1754. preprocess=preprocess_fn,
  1755. key_file=key_file,
  1756. )
  1757. if dataset.apply_utt2category:
  1758. kwargs.update(batch_size=1)
  1759. else:
  1760. kwargs.update(batch_size=batch_size)
  1761. cls.check_task_requirements(
  1762. dataset, allow_variable_data_keys, train=False, inference=inference
  1763. )
  1764. return DataLoader(
  1765. dataset=dataset,
  1766. pin_memory=ngpu > 0,
  1767. num_workers=num_workers,
  1768. **kwargs,
  1769. )
  1770. # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~
  1771. @classmethod
  1772. def build_model_from_file(
  1773. cls,
  1774. config_file: Union[Path, str] = None,
  1775. model_file: Union[Path, str] = None,
  1776. cmvn_file: Union[Path, str] = None,
  1777. device: str = "cpu",
  1778. ) -> Tuple[FunASRModel, argparse.Namespace]:
  1779. """Build model from the files.
  1780. This method is used for inference or fine-tuning.
  1781. Args:
  1782. config_file: The yaml file saved when training.
  1783. model_file: The model file saved when training.
  1784. device: Device type, "cpu", "cuda", or "cuda:N".
  1785. """
  1786. assert check_argument_types()
  1787. if config_file is None:
  1788. assert model_file is not None, (
  1789. "The argument 'model_file' must be provided "
  1790. "if the argument 'config_file' is not specified."
  1791. )
  1792. config_file = Path(model_file).parent / "config.yaml"
  1793. else:
  1794. config_file = Path(config_file)
  1795. with config_file.open("r", encoding="utf-8") as f:
  1796. args = yaml.safe_load(f)
  1797. if cmvn_file is not None:
  1798. args["cmvn_file"] = cmvn_file
  1799. args = argparse.Namespace(**args)
  1800. model = cls.build_model(args)
  1801. if not isinstance(model, FunASRModel):
  1802. raise RuntimeError(
  1803. f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
  1804. )
  1805. model.to(device)
  1806. if model_file is not None:
  1807. if device == "cuda":
  1808. # NOTE(kamo): "cuda" for torch.load always indicates cuda:0
  1809. # in PyTorch<=1.4
  1810. device = f"cuda:{torch.cuda.current_device()}"
  1811. model.load_state_dict(torch.load(model_file, map_location=device))
  1812. model.to(device)
  1813. return model, args