speech_asr 2 years ago
parent commit bd7455ec7d
3 changed files with 367 additions and 3 deletions
  1. funasr/bin/train.py (+326 -0)
  2. funasr/tasks/abs_task.py (+3 -3)
  3. funasr/utils/build_distributed.py (+38 -0)

+ 326 - 0
funasr/bin/train.py

@@ -0,0 +1,326 @@
+import sys
+
+import torch
+
+from funasr.utils import config_argparse
+from funasr.utils.build_distributed import build_distributed
+from funasr.utils.types import str2bool
+
+
+def get_parser():
+    parser = config_argparse.ArgumentParser(
+        description="FunASR Common Training Parser",
+    )
+
+    # common configuration
+    parser.add_argument("--output_dir", help="model save path")
+    parser.add_argument(
+        "--ngpu",
+        type=int,
+        default=0,
+        help="The number of gpus. 0 indicates CPU mode",
+    )
+    parser.add_argument("--seed", type=int, default=0, help="Random seed")
+
+    # ddp related
+    parser.add_argument(
+        "--dist_backend",
+        default="nccl",
+        type=str,
+        help="distributed backend",
+    )
+    parser.add_argument(
+        "--dist_init_method",
+        type=str,
+        default="env://",
+        help='if init_method="env://", env values of "MASTER_PORT", "MASTER_ADDR", '
+             '"WORLD_SIZE", and "RANK" are referred.',
+    )
+    parser.add_argument(
+        "--dist_world_size",
+        type=int,
+        default=None,
+        help="number of processes participating in distributed training",
+    )
+    parser.add_argument(
+        "--dist_rank",
+        type=int,
+        default=None,
+        help="node rank for distributed training",
+    )
+    parser.add_argument(
+        "--local_rank",
+        type=int,
+        default=None,
+        help="local rank for distributed training",
+    )
+    parser.add_argument(
+        "--unused_parameters",
+        type=str2bool,
+        default=False,
+        help="Whether to use the find_unused_parameters in "
+             "torch.nn.parallel.DistributedDataParallel ",
+    )
+
+    # cudnn related
+    parser.add_argument(
+        "--cudnn_enabled",
+        type=str2bool,
+        default=torch.backends.cudnn.enabled,
+        help="Enable CUDNN",
+    )
+    parser.add_argument(
+        "--cudnn_benchmark",
+        type=str2bool,
+        default=torch.backends.cudnn.benchmark,
+        help="Enable cudnn-benchmark mode",
+    )
+    parser.add_argument(
+        "--cudnn_deterministic",
+        type=str2bool,
+        default=True,
+        help="Enable cudnn-deterministic mode",
+    )
+
+    # trainer related
+    parser.add_argument(
+        "--max_epoch",
+        type=int,
+        default=40,
+        help="The maximum number epoch to train",
+    )
+    parser.add_argument(
+        "--max_update",
+        type=int,
+        default=sys.maxsize,
+        help="The maximum number update step to train",
+    )
+    parser.add_argument(
+        "--batch_interval",
+        type=int,
+        default=10000,
+        help="The batch interval for saving model.",
+    )
+    parser.add_argument(
+        "--patience",
+        type=int,
+        default=None,
+        help="Number of epochs to wait without improvement "
+             "before stopping the training",
+    )
+    parser.add_argument(
+        "--val_scheduler_criterion",
+        type=str,
+        nargs=2,
+        default=("valid", "loss"),
+        help="The criterion used for the value given to the lr scheduler. "
+             'Give a pair referring the phase, "train" or "valid",'
+             'and the criterion name. The mode specifying "min" or "max" can '
+             "be changed by --scheduler_conf",
+    )
+    parser.add_argument(
+        "--early_stopping_criterion",
+        type=str,
+        nargs=3,
+        default=("valid", "loss", "min"),
+        help="The criterion used for judging of early stopping. "
+             'Give a pair referring the phase, "train" or "valid",'
+             'the criterion name and the mode, "min" or "max", e.g. "acc,max".',
+    )
+    parser.add_argument(
+        "--best_model_criterion",
+        nargs="+",
+        default=[
+            ("train", "loss", "min"),
+            ("valid", "loss", "min"),
+            ("train", "acc", "max"),
+            ("valid", "acc", "max"),
+        ],
+        help="The criterion used for judging of the best model. "
+             'Give a pair referring the phase, "train" or "valid",'
+             'the criterion name, and the mode, "min" or "max", e.g. "acc,max".',
+    )
+    parser.add_argument(
+        "--keep_nbest_models",
+        type=int,
+        nargs="+",
+        default=[10],
+        help="Remove previous snapshots excluding the n-best scored epochs",
+    )
+    parser.add_argument(
+        "--nbest_averaging_interval",
+        type=int,
+        default=0,
+        help="The epoch interval to apply model averaging and save nbest models",
+    )
+    parser.add_argument(
+        "--grad_clip",
+        type=float,
+        default=5.0,
+        help="Gradient norm threshold to clip",
+    )
+    parser.add_argument(
+        "--grad_clip_type",
+        type=float,
+        default=2.0,
+        help="The type of the used p-norm for gradient clip. Can be inf",
+    )
+    parser.add_argument(
+        "--grad_noise",
+        type=str2bool,
+        default=False,
+        help="The flag to switch to use noise injection to "
+             "gradients during training",
+    )
+    parser.add_argument(
+        "--accum_grad",
+        type=int,
+        default=1,
+        help="The number of gradient accumulation",
+    )
+    parser.add_argument(
+        "--resume",
+        type=str2bool,
+        default=False,
+        help="Enable resuming if checkpoint is existing",
+    )
+    parser.add_argument(
+        "--use_amp",
+        type=str2bool,
+        default=False,
+        help="Enable Automatic Mixed Precision. This feature requires pytorch>=1.6",
+    )
+    parser.add_argument(
+        "--log_interval",
+        type=int,
+        default=None,
+        help="Log every this many iterations within each epoch during "
+             "training. If None is given, it is decided automatically "
+             "according to the number of training samples.",
+    )
+
+    # pretrained model related
+    parser.add_argument(
+        "--init_param",
+        type=str,
+        default=[],
+        nargs="*",
+        help="Specify the file path used for initialization of parameters. "
+             "The format is '<file_path>:<src_key>:<dst_key>:<exclude_keys>', "
+             "where file_path is the model file path, "
+             "src_key specifies the key of model states to be used in the model file, "
+             "dst_key specifies the attribute of the model to be initialized, "
+             "and exclude_keys excludes keys of model states for the initialization."
+             "e.g.\n"
+             "  # Load all parameters"
+             "  --init_param some/where/model.pb\n"
+             "  # Load only decoder parameters"
+             "  --init_param some/where/model.pb:decoder:decoder\n"
+             "  # Load only decoder parameters excluding decoder.embed"
+             "  --init_param some/where/model.pb:decoder:decoder:decoder.embed\n"
+             "  --init_param some/where/model.pb:decoder:decoder:decoder.embed\n",
+    )
+    parser.add_argument(
+        "--ignore_init_mismatch",
+        type=str2bool,
+        default=False,
+        help="Ignore size mismatch when loading pre-trained model",
+    )
+    parser.add_argument(
+        "--freeze_param",
+        type=str,
+        default=[],
+        nargs="*",
+        help="Freeze parameters",
+    )
+
+    # dataset related
+    parser.add_argument(
+        "--dataset_type",
+        type=str,
+        default="small",
+        help="whether to use dataloader for large dataset",
+    )
+    parser.add_argument(
+        "--train_data_file",
+        type=str,
+        default=None,
+        help="train_list for large dataset",
+    )
+    parser.add_argument(
+        "--valid_data_file",
+        type=str,
+        default=None,
+        help="valid_list for large dataset",
+    )
+    parser.add_argument(
+        "--train_data_path_and_name_and_type",
+        action="append",
+        default=[],
+        help="e.g. '--train_data_path_and_name_and_type some/path/a.scp,foo,sound'. ",
+    )
+    parser.add_argument(
+        "--valid_data_path_and_name_and_type",
+        action="append",
+        default=[],
+        help="Same format as --train_data_path_and_name_and_type, for validation data.",
+    )
+
+    # pai related
+    parser.add_argument(
+        "--use_pai",
+        type=str2bool,
+        default=False,
+        help="flag to indicate whether training on PAI",
+    )
+    parser.add_argument(
+        "--simple_ddp",
+        type=str2bool,
+        default=False,
+        help="flag to enable simple DDP initialization",
+    )
+    parser.add_argument(
+        "--num_worker_count",
+        type=int,
+        default=1,
+        help="The number of machines on PAI.",
+    )
+    parser.add_argument(
+        "--access_key_id",
+        type=str,
+        default=None,
+        help="The username for oss.",
+    )
+    parser.add_argument(
+        "--access_key_secret",
+        type=str,
+        default=None,
+        help="The password for oss.",
+    )
+    parser.add_argument(
+        "--endpoint",
+        type=str,
+        default=None,
+        help="The endpoint for oss.",
+    )
+    parser.add_argument(
+        "--bucket_name",
+        type=str,
+        default=None,
+        help="The bucket name for oss.",
+    )
+    parser.add_argument(
+        "--oss_bucket",
+        default=None,
+        help="oss bucket.",
+    )
+
+    # task related
+    parser.add_argument("--task_name", help="for different task")
+
+    return parser
+
+
+if __name__ == '__main__':
+    parser = get_parser()
+    args = parser.parse_args()
+
+    args.distributed = args.dist_world_size is not None and args.dist_world_size > 1
+    distributed_option = build_distributed(args)
+
+    #
+
+
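For reference, a minimal sketch of how the parser and the __main__ block above might be exercised in a single-process run; the argument values are illustrative assumptions, not part of this commit:

    from funasr.bin.train import get_parser

    parser = get_parser()
    # Hypothetical single-node CPU configuration; a real run would also pass
    # task and data arguments (--task_name, --train_data_file, ...).
    args = parser.parse_args([
        "--output_dir", "exp/demo",
        "--ngpu", "0",
        "--dist_world_size", "1",
    ])
    # Mirrors the __main__ block above: a world size of 1 means non-distributed.
    args.distributed = args.dist_world_size is not None and args.dist_world_size > 1
    assert args.distributed is False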

+ 3 - 3
funasr/tasks/abs_task.py

@@ -30,6 +30,7 @@ import torch.multiprocessing
 import torch.nn
 import torch.optim
 import yaml
+from funasr.train.abs_espnet_model import AbsESPnetModel
 from torch.utils.data import DataLoader
 from typeguard import check_argument_types
 from typeguard import check_return_type
@@ -44,19 +45,18 @@ from funasr.iterators.chunk_iter_factory import ChunkIterFactory
 from funasr.iterators.multiple_iter_factory import MultipleIterFactory
 from funasr.iterators.sequence_iter_factory import SequenceIterFactory
 from funasr.main_funcs.collect_stats import collect_stats
-from funasr.optimizers.sgd import SGD
 from funasr.optimizers.fairseq_adam import FairseqAdam
+from funasr.optimizers.sgd import SGD
 from funasr.samplers.build_batch_sampler import BATCH_TYPES
 from funasr.samplers.build_batch_sampler import build_batch_sampler
 from funasr.samplers.unsorted_batch_sampler import UnsortedBatchSampler
 from funasr.schedulers.noam_lr import NoamLR
-from funasr.schedulers.warmup_lr import WarmupLR
 from funasr.schedulers.tri_stage_scheduler import TriStageLR
+from funasr.schedulers.warmup_lr import WarmupLR
 from funasr.torch_utils.load_pretrained_model import load_pretrained_model
 from funasr.torch_utils.model_summary import model_summary
 from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
 from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.train.abs_espnet_model import AbsESPnetModel
 from funasr.train.class_choices import ClassChoices
 from funasr.train.distributed_utils import DistributedOption
 from funasr.train.trainer import Trainer

+ 38 - 0
funasr/utils/build_distributed.py

@@ -0,0 +1,38 @@
+import logging
+import os
+
+import torch
+
+from funasr.train.distributed_utils import DistributedOption
+from funasr.utils.build_dataclass import build_dataclass
+
+
+def build_distributed(args):
+    distributed_option = build_dataclass(DistributedOption, args)
+    if args.use_pai:
+        # Training on PAI: use the PAI-specific initialization helpers.
+        distributed_option.init_options_pai()
+        distributed_option.init_torch_distributed_pai(args)
+    elif not args.simple_ddp:
+        # Standard launch: initialize torch.distributed from
+        # --dist_init_method (e.g. env://).
+        distributed_option.init_torch_distributed(args)
+    elif args.distributed and args.simple_ddp:
+        # Simple DDP: reuse the PAI-style initialization and derive ngpu
+        # from the actual world size.
+        distributed_option.init_torch_distributed_pai(args)
+        args.ngpu = torch.distributed.get_world_size()
+
+    # Reconfigure logging: rank 0 (or a non-distributed run) logs at INFO,
+    # all other ranks log only errors.
+    for handler in logging.root.handlers[:]:
+        logging.root.removeHandler(handler)
+    if not distributed_option.distributed or distributed_option.dist_rank == 0:
+        logging.basicConfig(
+            level="INFO",
+            format=f"[{os.uname()[1].split('.')[0]}]"
+                   f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+        )
+    else:
+        logging.basicConfig(
+            level="ERROR",
+            format=f"[{os.uname()[1].split('.')[0]}]"
+                   f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+        )
+    logging.info("world size: {}, rank: {}, local_rank: {}".format(distributed_option.dist_world_size,
+                                                                   distributed_option.dist_rank,
+                                                                   distributed_option.local_rank))
+    return distributed_option
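As a usage sketch: with the default --dist_init_method of env://, torch.distributed reads MASTER_ADDR, MASTER_PORT, WORLD_SIZE, and RANK from the environment (see the --dist_init_method help above). The values below are illustrative assumptions for a single-process run; a launcher such as torchrun would normally export them:

    import os
    from funasr.utils.build_distributed import build_distributed

    # Hypothetical env:// setup for a one-process "world".
    os.environ.update({
        "MASTER_ADDR": "127.0.0.1",
        "MASTER_PORT": "29500",
        "WORLD_SIZE": "1",
        "RANK": "0",
    })
    # args as produced by get_parser() in funasr/bin/train.py (see above),
    # with args.distributed already set.
    distributed_option = build_distributed(args)
    # From here on, rank 0 logs at INFO level; other ranks log only errors.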