|
|
@@ -266,8 +266,9 @@ class AbsTask(ABC):
|
|
|
def build_model(cls, args: argparse.Namespace) -> FunASRModel:
|
|
|
raise NotImplementedError
|
|
|
|
|
|
+
|
|
|
@classmethod
|
|
|
- def get_parser(cls, parser) -> config_argparse.ArgumentParser:
|
|
|
+ def get_parser(cls) -> config_argparse.ArgumentParser:
|
|
|
assert check_argument_types()
|
|
|
|
|
|
class ArgumentDefaultsRawTextHelpFormatter(
|
|
|
@@ -276,10 +277,10 @@ class AbsTask(ABC):
|
|
|
):
|
|
|
pass
|
|
|
|
|
|
- # parser = config_argparse.ArgumentParser(
|
|
|
- # description="base parser",
|
|
|
- # formatter_class=ArgumentDefaultsRawTextHelpFormatter,
|
|
|
- # )
|
|
|
+ parser = config_argparse.ArgumentParser(
|
|
|
+ description="base parser",
|
|
|
+ formatter_class=ArgumentDefaultsRawTextHelpFormatter,
|
|
|
+ )
|
|
|
|
|
|
# NOTE(kamo): Use '_' instead of '-' to avoid confusion.
|
|
|
# I think '-' looks really confusing if it's written in yaml.
|
|
|
@@ -961,702 +962,6 @@ class AbsTask(ABC):
|
|
|
assert check_return_type(parser)
|
|
|
return parser
|
|
|
|
|
|
-
|
|
|
- # @classmethod
|
|
|
- # def get_parser(cls) -> config_argparse.ArgumentParser:
|
|
|
- # assert check_argument_types()
|
|
|
- #
|
|
|
- # class ArgumentDefaultsRawTextHelpFormatter(
|
|
|
- # argparse.RawTextHelpFormatter,
|
|
|
- # argparse.ArgumentDefaultsHelpFormatter,
|
|
|
- # ):
|
|
|
- # pass
|
|
|
- #
|
|
|
- # parser = config_argparse.ArgumentParser(
|
|
|
- # description="base parser",
|
|
|
- # formatter_class=ArgumentDefaultsRawTextHelpFormatter,
|
|
|
- # )
|
|
|
- #
|
|
|
- # # NOTE(kamo): Use '_' instead of '-' to avoid confusion.
|
|
|
- # # I think '-' looks really confusing if it's written in yaml.
|
|
|
- #
|
|
|
- # # NOTE(kamo): add_arguments(..., required=True) can't be used
|
|
|
- # # to provide --print_config mode. Instead of it, do as
|
|
|
- # # parser.set_defaults(required=["output_dir"])
|
|
|
- #
|
|
|
- # group = parser.add_argument_group("Common configuration")
|
|
|
- #
|
|
|
- # group.add_argument(
|
|
|
- # "--print_config",
|
|
|
- # action="store_true",
|
|
|
- # help="Print the config file and exit",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--log_level",
|
|
|
- # type=lambda x: x.upper(),
|
|
|
- # default="INFO",
|
|
|
- # choices=("ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
|
|
|
- # help="The verbose level of logging",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--dry_run",
|
|
|
- # type=str2bool,
|
|
|
- # default=False,
|
|
|
- # help="Perform process without training",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--iterator_type",
|
|
|
- # type=str,
|
|
|
- # choices=["sequence", "chunk", "task", "none"],
|
|
|
- # default="sequence",
|
|
|
- # help="Specify iterator type",
|
|
|
- # )
|
|
|
- #
|
|
|
- # group.add_argument("--output_dir", type=str_or_none, default=None)
|
|
|
- # group.add_argument(
|
|
|
- # "--ngpu",
|
|
|
- # type=int,
|
|
|
- # default=0,
|
|
|
- # help="The number of gpus. 0 indicates CPU mode",
|
|
|
- # )
|
|
|
- # group.add_argument("--seed", type=int, default=0, help="Random seed")
|
|
|
- # group.add_argument(
|
|
|
- # "--num_workers",
|
|
|
- # type=int,
|
|
|
- # default=1,
|
|
|
- # help="The number of workers used for DataLoader",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--num_att_plot",
|
|
|
- # type=int,
|
|
|
- # default=3,
|
|
|
- # help="The number images to plot the outputs from attention. "
|
|
|
- # "This option makes sense only when attention-based model. "
|
|
|
- # "We can also disable the attention plot by setting it 0",
|
|
|
- # )
|
|
|
- #
|
|
|
- # group = parser.add_argument_group("distributed training related")
|
|
|
- # group.add_argument(
|
|
|
- # "--dist_backend",
|
|
|
- # default="nccl",
|
|
|
- # type=str,
|
|
|
- # help="distributed backend",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--dist_init_method",
|
|
|
- # type=str,
|
|
|
- # default="env://",
|
|
|
- # help='if init_method="env://", env values of "MASTER_PORT", "MASTER_ADDR", '
|
|
|
- # '"WORLD_SIZE", and "RANK" are referred.',
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--dist_world_size",
|
|
|
- # default=None,
|
|
|
- # type=int_or_none,
|
|
|
- # help="number of nodes for distributed training",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--dist_rank",
|
|
|
- # type=int_or_none,
|
|
|
- # default=None,
|
|
|
- # help="node rank for distributed training",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # # Not starting with "dist_" for compatibility to launch.py
|
|
|
- # "--local_rank",
|
|
|
- # type=int_or_none,
|
|
|
- # default=None,
|
|
|
- # help="local rank for distributed training. This option is used if "
|
|
|
- # "--multiprocessing_distributed=false",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--dist_master_addr",
|
|
|
- # default=None,
|
|
|
- # type=str_or_none,
|
|
|
- # help="The master address for distributed training. "
|
|
|
- # "This value is used when dist_init_method == 'env://'",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--dist_master_port",
|
|
|
- # default=None,
|
|
|
- # type=int_or_none,
|
|
|
- # help="The master port for distributed training"
|
|
|
- # "This value is used when dist_init_method == 'env://'",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--dist_launcher",
|
|
|
- # default=None,
|
|
|
- # type=str_or_none,
|
|
|
- # choices=["slurm", "mpi", None],
|
|
|
- # help="The launcher type for distributed training",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--multiprocessing_distributed",
|
|
|
- # default=False,
|
|
|
- # type=str2bool,
|
|
|
- # help="Use multi-processing distributed training to launch "
|
|
|
- # "N processes per node, which has N GPUs. This is the "
|
|
|
- # "fastest way to use PyTorch for either single node or "
|
|
|
- # "multi node data parallel training",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--unused_parameters",
|
|
|
- # type=str2bool,
|
|
|
- # default=False,
|
|
|
- # help="Whether to use the find_unused_parameters in "
|
|
|
- # "torch.nn.parallel.DistributedDataParallel ",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--sharded_ddp",
|
|
|
- # default=False,
|
|
|
- # type=str2bool,
|
|
|
- # help="Enable sharded training provided by fairscale",
|
|
|
- # )
|
|
|
- #
|
|
|
- # group = parser.add_argument_group("cudnn mode related")
|
|
|
- # group.add_argument(
|
|
|
- # "--cudnn_enabled",
|
|
|
- # type=str2bool,
|
|
|
- # default=torch.backends.cudnn.enabled,
|
|
|
- # help="Enable CUDNN",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--cudnn_benchmark",
|
|
|
- # type=str2bool,
|
|
|
- # default=torch.backends.cudnn.benchmark,
|
|
|
- # help="Enable cudnn-benchmark mode",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--cudnn_deterministic",
|
|
|
- # type=str2bool,
|
|
|
- # default=True,
|
|
|
- # help="Enable cudnn-deterministic mode",
|
|
|
- # )
|
|
|
- #
|
|
|
- # group = parser.add_argument_group("collect stats mode related")
|
|
|
- # group.add_argument(
|
|
|
- # "--collect_stats",
|
|
|
- # type=str2bool,
|
|
|
- # default=False,
|
|
|
- # help='Perform on "collect stats" mode',
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--mc",
|
|
|
- # type=bool,
|
|
|
- # default=False,
|
|
|
- # help="MultiChannel input",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--write_collected_feats",
|
|
|
- # type=str2bool,
|
|
|
- # default=False,
|
|
|
- # help='Write the output features from the model when "collect stats" mode',
|
|
|
- # )
|
|
|
- #
|
|
|
- # group = parser.add_argument_group("Trainer related")
|
|
|
- # group.add_argument(
|
|
|
- # "--max_epoch",
|
|
|
- # type=int,
|
|
|
- # default=40,
|
|
|
- # help="The maximum number epoch to train",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--max_update",
|
|
|
- # type=int,
|
|
|
- # default=sys.maxsize,
|
|
|
- # help="The maximum number update step to train",
|
|
|
- # )
|
|
|
- # parser.add_argument(
|
|
|
- # "--batch_interval",
|
|
|
- # type=int,
|
|
|
- # default=-1,
|
|
|
- # help="The batch interval for saving model.",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--patience",
|
|
|
- # type=int_or_none,
|
|
|
- # default=None,
|
|
|
- # help="Number of epochs to wait without improvement "
|
|
|
- # "before stopping the training",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--val_scheduler_criterion",
|
|
|
- # type=str,
|
|
|
- # nargs=2,
|
|
|
- # default=("valid", "loss"),
|
|
|
- # help="The criterion used for the value given to the lr scheduler. "
|
|
|
- # 'Give a pair referring the phase, "train" or "valid",'
|
|
|
- # 'and the criterion name. The mode specifying "min" or "max" can '
|
|
|
- # "be changed by --scheduler_conf",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--early_stopping_criterion",
|
|
|
- # type=str,
|
|
|
- # nargs=3,
|
|
|
- # default=("valid", "loss", "min"),
|
|
|
- # help="The criterion used for judging of early stopping. "
|
|
|
- # 'Give a pair referring the phase, "train" or "valid",'
|
|
|
- # 'the criterion name and the mode, "min" or "max", e.g. "acc,max".',
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--best_model_criterion",
|
|
|
- # type=str2triple_str,
|
|
|
- # nargs="+",
|
|
|
- # default=[
|
|
|
- # ("train", "loss", "min"),
|
|
|
- # ("valid", "loss", "min"),
|
|
|
- # ("train", "acc", "max"),
|
|
|
- # ("valid", "acc", "max"),
|
|
|
- # ],
|
|
|
- # help="The criterion used for judging of the best model. "
|
|
|
- # 'Give a pair referring the phase, "train" or "valid",'
|
|
|
- # 'the criterion name, and the mode, "min" or "max", e.g. "acc,max".',
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--keep_nbest_models",
|
|
|
- # type=int,
|
|
|
- # nargs="+",
|
|
|
- # default=[10],
|
|
|
- # help="Remove previous snapshots excluding the n-best scored epochs",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--nbest_averaging_interval",
|
|
|
- # type=int,
|
|
|
- # default=0,
|
|
|
- # help="The epoch interval to apply model averaging and save nbest models",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--grad_clip",
|
|
|
- # type=float,
|
|
|
- # default=5.0,
|
|
|
- # help="Gradient norm threshold to clip",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--grad_clip_type",
|
|
|
- # type=float,
|
|
|
- # default=2.0,
|
|
|
- # help="The type of the used p-norm for gradient clip. Can be inf",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--grad_noise",
|
|
|
- # type=str2bool,
|
|
|
- # default=False,
|
|
|
- # help="The flag to switch to use noise injection to "
|
|
|
- # "gradients during training",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--accum_grad",
|
|
|
- # type=int,
|
|
|
- # default=1,
|
|
|
- # help="The number of gradient accumulation",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--bias_grad_times",
|
|
|
- # type=float,
|
|
|
- # default=1.0,
|
|
|
- # help="To scale the gradient of contextual related params",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--no_forward_run",
|
|
|
- # type=str2bool,
|
|
|
- # default=False,
|
|
|
- # help="Just only iterating data loading without "
|
|
|
- # "model forwarding and training",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--resume",
|
|
|
- # type=str2bool,
|
|
|
- # default=False,
|
|
|
- # help="Enable resuming if checkpoint is existing",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--train_dtype",
|
|
|
- # default="float32",
|
|
|
- # choices=["float16", "float32", "float64"],
|
|
|
- # help="Data type for training.",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--use_amp",
|
|
|
- # type=str2bool,
|
|
|
- # default=False,
|
|
|
- # help="Enable Automatic Mixed Precision. This feature requires pytorch>=1.6",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--log_interval",
|
|
|
- # type=int_or_none,
|
|
|
- # default=None,
|
|
|
- # help="Show the logs every the number iterations in each epochs at the "
|
|
|
- # "training phase. If None is given, it is decided according the number "
|
|
|
- # "of training samples automatically .",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--use_tensorboard",
|
|
|
- # type=str2bool,
|
|
|
- # default=True,
|
|
|
- # help="Enable tensorboard logging",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--use_wandb",
|
|
|
- # type=str2bool,
|
|
|
- # default=False,
|
|
|
- # help="Enable wandb logging",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--wandb_project",
|
|
|
- # type=str,
|
|
|
- # default=None,
|
|
|
- # help="Specify wandb project",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--wandb_id",
|
|
|
- # type=str,
|
|
|
- # default=None,
|
|
|
- # help="Specify wandb id",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--wandb_entity",
|
|
|
- # type=str,
|
|
|
- # default=None,
|
|
|
- # help="Specify wandb entity",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--wandb_name",
|
|
|
- # type=str,
|
|
|
- # default=None,
|
|
|
- # help="Specify wandb run name",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--wandb_model_log_interval",
|
|
|
- # type=int,
|
|
|
- # default=-1,
|
|
|
- # help="Set the model log period",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--detect_anomaly",
|
|
|
- # type=str2bool,
|
|
|
- # default=False,
|
|
|
- # help="Set torch.autograd.set_detect_anomaly",
|
|
|
- # )
|
|
|
- #
|
|
|
- # group = parser.add_argument_group("Pretraining model related")
|
|
|
- # group.add_argument("--pretrain_path", help="This option is obsoleted")
|
|
|
- # group.add_argument(
|
|
|
- # "--init_param",
|
|
|
- # type=str,
|
|
|
- # action="append",
|
|
|
- # default=[],
|
|
|
- # help="Specify the file path used for initialization of parameters. "
|
|
|
- # "The format is '<file_path>:<src_key>:<dst_key>:<exclude_keys>', "
|
|
|
- # "where file_path is the model file path, "
|
|
|
- # "src_key specifies the key of model states to be used in the model file, "
|
|
|
- # "dst_key specifies the attribute of the model to be initialized, "
|
|
|
- # "and exclude_keys excludes keys of model states for the initialization."
|
|
|
- # "e.g.\n"
|
|
|
- # " # Load all parameters"
|
|
|
- # " --init_param some/where/model.pb\n"
|
|
|
- # " # Load only decoder parameters"
|
|
|
- # " --init_param some/where/model.pb:decoder:decoder\n"
|
|
|
- # " # Load only decoder parameters excluding decoder.embed"
|
|
|
- # " --init_param some/where/model.pb:decoder:decoder:decoder.embed\n"
|
|
|
- # " --init_param some/where/model.pb:decoder:decoder:decoder.embed\n",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--ignore_init_mismatch",
|
|
|
- # type=str2bool,
|
|
|
- # default=False,
|
|
|
- # help="Ignore size mismatch when loading pre-trained model",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--freeze_param",
|
|
|
- # type=str,
|
|
|
- # default=[],
|
|
|
- # action="append",
|
|
|
- # help="Freeze parameters",
|
|
|
- # )
|
|
|
- #
|
|
|
- # group = parser.add_argument_group("BatchSampler related")
|
|
|
- # group.add_argument(
|
|
|
- # "--num_iters_per_epoch",
|
|
|
- # type=int_or_none,
|
|
|
- # default=None,
|
|
|
- # help="Restrict the number of iterations for training per epoch",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--batch_size",
|
|
|
- # type=int,
|
|
|
- # default=20,
|
|
|
- # help="The mini-batch size used for training. Used if batch_type='unsorted',"
|
|
|
- # " 'sorted', or 'folded'.",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--valid_batch_size",
|
|
|
- # type=int_or_none,
|
|
|
- # default=None,
|
|
|
- # help="If not given, the value of --batch_size is used",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--batch_bins",
|
|
|
- # type=int,
|
|
|
- # default=1000000,
|
|
|
- # help="The number of batch bins. Used if batch_type='length' or 'numel'",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--valid_batch_bins",
|
|
|
- # type=int_or_none,
|
|
|
- # default=None,
|
|
|
- # help="If not given, the value of --batch_bins is used",
|
|
|
- # )
|
|
|
- #
|
|
|
- # group.add_argument("--train_shape_file", type=str, action="append", default=[])
|
|
|
- # group.add_argument("--valid_shape_file", type=str, action="append", default=[])
|
|
|
- #
|
|
|
- # group = parser.add_argument_group("Sequence iterator related")
|
|
|
- # _batch_type_help = ""
|
|
|
- # for key, value in BATCH_TYPES.items():
|
|
|
- # _batch_type_help += f'"{key}":\n{value}\n'
|
|
|
- # group.add_argument(
|
|
|
- # "--batch_type",
|
|
|
- # type=str,
|
|
|
- # default="length",
|
|
|
- # choices=list(BATCH_TYPES),
|
|
|
- # help=_batch_type_help,
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--valid_batch_type",
|
|
|
- # type=str_or_none,
|
|
|
- # default=None,
|
|
|
- # choices=list(BATCH_TYPES) + [None],
|
|
|
- # help="If not given, the value of --batch_type is used",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--speech_length_min",
|
|
|
- # type=int,
|
|
|
- # default=-1,
|
|
|
- # help="speech length min",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--speech_length_max",
|
|
|
- # type=int,
|
|
|
- # default=-1,
|
|
|
- # help="speech length max",
|
|
|
- # )
|
|
|
- # group.add_argument("--fold_length", type=int, action="append", default=[])
|
|
|
- # group.add_argument(
|
|
|
- # "--sort_in_batch",
|
|
|
- # type=str,
|
|
|
- # default="descending",
|
|
|
- # choices=["descending", "ascending"],
|
|
|
- # help="Sort the samples in each mini-batches by the sample "
|
|
|
- # 'lengths. To enable this, "shape_file" must have the length information.',
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--sort_batch",
|
|
|
- # type=str,
|
|
|
- # default="descending",
|
|
|
- # choices=["descending", "ascending"],
|
|
|
- # help="Sort mini-batches by the sample lengths",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--multiple_iterator",
|
|
|
- # type=str2bool,
|
|
|
- # default=False,
|
|
|
- # help="Use multiple iterator mode",
|
|
|
- # )
|
|
|
- #
|
|
|
- # group = parser.add_argument_group("Chunk iterator related")
|
|
|
- # group.add_argument(
|
|
|
- # "--chunk_length",
|
|
|
- # type=str_or_int,
|
|
|
- # default=500,
|
|
|
- # help="Specify chunk length. e.g. '300', '300,400,500', or '300-400'."
|
|
|
- # "If multiple numbers separated by command are given, "
|
|
|
- # "one of them is selected randomly for each samples. "
|
|
|
- # "If two numbers are given with '-', it indicates the range of the choices. "
|
|
|
- # "Note that if the sequence length is shorter than the all chunk_lengths, "
|
|
|
- # "the sample is discarded. ",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--chunk_shift_ratio",
|
|
|
- # type=float,
|
|
|
- # default=0.5,
|
|
|
- # help="Specify the shift width of chunks. If it's less than 1, "
|
|
|
- # "allows the overlapping and if bigger than 1, there are some gaps "
|
|
|
- # "between each chunk.",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--num_cache_chunks",
|
|
|
- # type=int,
|
|
|
- # default=1024,
|
|
|
- # help="Shuffle in the specified number of chunks and generate mini-batches "
|
|
|
- # "More larger this value, more randomness can be obtained.",
|
|
|
- # )
|
|
|
- #
|
|
|
- # group = parser.add_argument_group("Dataset related")
|
|
|
- # _data_path_and_name_and_type_help = (
|
|
|
- # "Give three words splitted by comma. It's used for the training data. "
|
|
|
- # "e.g. '--train_data_path_and_name_and_type some/path/a.scp,foo,sound'. "
|
|
|
- # "The first value, some/path/a.scp, indicates the file path, "
|
|
|
- # "and the second, foo, is the key name used for the mini-batch data, "
|
|
|
- # "and the last, sound, decides the file type. "
|
|
|
- # "This option is repeatable, so you can input any number of features "
|
|
|
- # "for your task. Supported file types are as follows:\n\n"
|
|
|
- # )
|
|
|
- # for key, dic in DATA_TYPES.items():
|
|
|
- # _data_path_and_name_and_type_help += f'"{key}":\n{dic["help"]}\n\n'
|
|
|
- #
|
|
|
- # # for large dataset
|
|
|
- # group.add_argument(
|
|
|
- # "--dataset_type",
|
|
|
- # type=str,
|
|
|
- # default="small",
|
|
|
- # help="whether to use dataloader for large dataset",
|
|
|
- # )
|
|
|
- # parser.add_argument(
|
|
|
- # "--dataset_conf",
|
|
|
- # action=NestedDictAction,
|
|
|
- # default=dict(),
|
|
|
- # help=f"The keyword arguments for dataset",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--train_data_file",
|
|
|
- # type=str,
|
|
|
- # default=None,
|
|
|
- # help="train_list for large dataset",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--valid_data_file",
|
|
|
- # type=str,
|
|
|
- # default=None,
|
|
|
- # help="valid_list for large dataset",
|
|
|
- # )
|
|
|
- #
|
|
|
- # group.add_argument(
|
|
|
- # "--train_data_path_and_name_and_type",
|
|
|
- # type=str2triple_str,
|
|
|
- # action="append",
|
|
|
- # default=[],
|
|
|
- # help=_data_path_and_name_and_type_help,
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--valid_data_path_and_name_and_type",
|
|
|
- # type=str2triple_str,
|
|
|
- # action="append",
|
|
|
- # default=[],
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--allow_variable_data_keys",
|
|
|
- # type=str2bool,
|
|
|
- # default=False,
|
|
|
- # help="Allow the arbitrary keys for mini-batch with ignoring "
|
|
|
- # "the task requirements",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--max_cache_size",
|
|
|
- # type=humanfriendly.parse_size,
|
|
|
- # default=0.0,
|
|
|
- # help="The maximum cache size for data loader. e.g. 10MB, 20GB.",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--max_cache_fd",
|
|
|
- # type=int,
|
|
|
- # default=32,
|
|
|
- # help="The maximum number of file descriptors to be kept "
|
|
|
- # "as opened for ark files. "
|
|
|
- # "This feature is only valid when data type is 'kaldi_ark'.",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--valid_max_cache_size",
|
|
|
- # type=humanfriendly_parse_size_or_none,
|
|
|
- # default=None,
|
|
|
- # help="The maximum cache size for validation data loader. e.g. 10MB, 20GB. "
|
|
|
- # "If None, the 5 percent size of --max_cache_size",
|
|
|
- # )
|
|
|
- #
|
|
|
- # group = parser.add_argument_group("Optimizer related")
|
|
|
- # for i in range(1, cls.num_optimizers + 1):
|
|
|
- # suf = "" if i == 1 else str(i)
|
|
|
- # group.add_argument(
|
|
|
- # f"--optim{suf}",
|
|
|
- # type=lambda x: x.lower(),
|
|
|
- # default="adadelta",
|
|
|
- # choices=list(optim_classes),
|
|
|
- # help="The optimizer type",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # f"--optim{suf}_conf",
|
|
|
- # action=NestedDictAction,
|
|
|
- # default=dict(),
|
|
|
- # help="The keyword arguments for optimizer",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # f"--scheduler{suf}",
|
|
|
- # type=lambda x: str_or_none(x.lower()),
|
|
|
- # default=None,
|
|
|
- # choices=list(scheduler_classes) + [None],
|
|
|
- # help="The lr scheduler type",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # f"--scheduler{suf}_conf",
|
|
|
- # action=NestedDictAction,
|
|
|
- # default=dict(),
|
|
|
- # help="The keyword arguments for lr scheduler",
|
|
|
- # )
|
|
|
- #
|
|
|
- # # for training on PAI
|
|
|
- # group = parser.add_argument_group("PAI training related")
|
|
|
- # group.add_argument(
|
|
|
- # "--use_pai",
|
|
|
- # type=str2bool,
|
|
|
- # default=False,
|
|
|
- # help="flag to indicate whether training on PAI",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--simple_ddp",
|
|
|
- # type=str2bool,
|
|
|
- # default=False,
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--num_worker_count",
|
|
|
- # type=int,
|
|
|
- # default=1,
|
|
|
- # help="The number of machines on PAI.",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--access_key_id",
|
|
|
- # type=str,
|
|
|
- # default=None,
|
|
|
- # help="The username for oss.",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--access_key_secret",
|
|
|
- # type=str,
|
|
|
- # default=None,
|
|
|
- # help="The password for oss.",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--endpoint",
|
|
|
- # type=str,
|
|
|
- # default=None,
|
|
|
- # help="The endpoint for oss.",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--bucket_name",
|
|
|
- # type=str,
|
|
|
- # default=None,
|
|
|
- # help="The bucket name for oss.",
|
|
|
- # )
|
|
|
- # group.add_argument(
|
|
|
- # "--oss_bucket",
|
|
|
- # default=None,
|
|
|
- # help="oss bucket.",
|
|
|
- # )
|
|
|
- #
|
|
|
- # cls.trainer.add_arguments(parser)
|
|
|
- # cls.add_task_arguments(parser)
|
|
|
- #
|
|
|
- # assert check_return_type(parser)
|
|
|
- # return parser
|
|
|
-
|
|
|
@classmethod
|
|
|
def build_optimizers(
|
|
|
cls,
|