# build_trainer.py
import os
import yaml
  3. def update_dct(fin_configs, root):
  4. if root == {}:
  5. return {}
  6. for root_key, root_value in root.items():
  7. if not isinstance(root[root_key], dict):
  8. fin_configs[root_key] = root[root_key]
  9. else:
  10. if root_key in fin_configs.keys():
  11. result = update_dct(fin_configs[root_key], root[root_key])
  12. fin_configs[root_key] = result
  13. else:
  14. fin_configs[root_key] = root[root_key]
  15. return fin_configs
  16. def parse_args(mode):
  17. if mode == "asr":
  18. from funasr.tasks.asr import ASRTask as ASRTask
  19. elif mode == "paraformer":
  20. from funasr.tasks.asr import ASRTaskParaformer as ASRTask
  21. elif mode == "paraformer_vad_punc":
  22. from funasr.tasks.asr import ASRTaskParaformer as ASRTask
  23. elif mode == "uniasr":
  24. from funasr.tasks.asr import ASRTaskUniASR as ASRTask
  25. elif mode == "mfcca":
  26. from funasr.tasks.asr import ASRTaskMFCCA as ASRTask
  27. elif mode == "tp":
  28. from funasr.tasks.asr import ASRTaskAligner as ASRTask
  29. else:
  30. raise ValueError("Unknown mode: {}".format(mode))
  31. parser = ASRTask.get_parser()
  32. args = parser.parse_args()
  33. return args, ASRTask
def build_trainer(modelscope_dict,
                  data_dir,
                  output_dir,
                  train_set="train",
                  dev_set="validation",
                  distributed=False,
                  dataset_type="small",
                  batch_bins=None,
                  max_epoch=None,
                  optim=None,
                  lr=None,
                  scheduler=None,
                  scheduler_conf=None,
                  specaug=None,
                  specaug_conf=None,
                  param_dict=None,
                  **kwargs):
    """Configure a funasr ASR task for fine-tuning a ModelScope model.

    Loads the pretrained model's config, overlays the fine-tune config on top
    of it, applies the keyword overrides (optimizer, scheduler, specaug,
    epochs, batch size), wires up train/dev data paths, and stores the
    resulting argparse namespace on the selected task class.

    Args:
        modelscope_dict: dict describing the downloaded ModelScope model; the
            keys read here are 'mode', 'am_model_config', 'finetune_config',
            'init_model', 'cmvn_file' and 'seg_dict'.
        data_dir: root directory holding the train/dev subdirectories.
        output_dir: directory where training artifacts are written.
        train_set: name of the training subdirectory under ``data_dir``.
        dev_set: name of the validation subdirectory under ``data_dir``.
        distributed: NOTE(review): this argument is effectively ignored — it
            is unconditionally recomputed from ``args.local_rank`` below.
        dataset_type: "small" (explicit wav.scp/text path lists) or "large"
            (data files left for the task to resolve).
        batch_bins: optional batch-size override; meaning differs by
            ``dataset_type`` (see below).
        max_epoch: optional override for the number of training epochs.
        optim: optional optimizer-name override.
        lr: optional learning-rate override (written into ``optim_conf``).
        scheduler: optional LR-scheduler-name override.
        scheduler_conf: optional LR-scheduler config override.
        specaug: optional SpecAugment-name override.
        specaug_conf: optional SpecAugment config override.
        param_dict: unused in this function; kept for interface compatibility.
        **kwargs: unused extra options, accepted for forward compatibility.

    Returns:
        The task class (e.g. ASRTask) with ``finetune_args`` set to the fully
        populated argparse namespace.
    """
    mode = modelscope_dict['mode']
    args, ASRTask = parse_args(mode=mode)
    # ddp related: presence of --local_rank on the command line (set by the
    # torch.distributed launcher) decides distributed mode; the incoming
    # ``distributed`` parameter is overwritten either way.
    if args.local_rank is not None:
        distributed = True
    else:
        distributed = False
    args.local_rank = args.local_rank if args.local_rank is not None else 0
    local_rank = args.local_rank
    # Pin this process to a single GPU: pick the local_rank-th entry of an
    # existing CUDA_VISIBLE_DEVICES list, or default to GPU ``local_rank``.
    if "CUDA_VISIBLE_DEVICES" in os.environ.keys():
        gpu_list = os.environ['CUDA_VISIBLE_DEVICES'].split(",")
        os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_list[args.local_rank])
    else:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.local_rank)
    config = modelscope_dict['am_model_config']
    finetune_config = modelscope_dict['finetune_config']
    init_param = modelscope_dict['init_model']
    cmvn_file = modelscope_dict['cmvn_file']
    seg_dict_file = modelscope_dict['seg_dict']
    # overwrite parameters: merge the fine-tune config over the pretrained
    # model's config, then copy every resulting key onto ``args`` (only keys
    # that already exist as parser attributes are applied).
    with open(config) as f:
        configs = yaml.safe_load(f)
    with open(finetune_config) as f:
        finetune_configs = yaml.safe_load(f)
    # set data_types
    if dataset_type == "large":
        finetune_configs["dataset_conf"]["data_types"] = "sound,text"
    finetune_configs = update_dct(configs, finetune_configs)
    for key, value in finetune_configs.items():
        if hasattr(args, key):
            setattr(args, key, value)
    # prepare data: "small" datasets enumerate wav.scp/text triples directly;
    # "large" datasets leave the data files to be resolved by the task.
    args.dataset_type = dataset_type
    if args.dataset_type == "small":
        args.train_data_path_and_name_and_type = [["{}/{}/wav.scp".format(data_dir, train_set), "speech", "sound"],
                                                  ["{}/{}/text".format(data_dir, train_set), "text", "text"]]
        args.valid_data_path_and_name_and_type = [["{}/{}/wav.scp".format(data_dir, dev_set), "speech", "sound"],
                                                  ["{}/{}/text".format(data_dir, dev_set), "text", "text"]]
    elif args.dataset_type == "large":
        args.train_data_file = None
        args.valid_data_file = None
    else:
        raise ValueError(f"Not supported dataset_type={args.dataset_type}")
    args.init_param = [init_param]
    args.cmvn_file = cmvn_file
    # The segmentation dictionary is optional; only pass it if it exists.
    if os.path.exists(seg_dict_file):
        args.seg_dict_file = seg_dict_file
    else:
        args.seg_dict_file = None
    args.data_dir = data_dir
    args.train_set = train_set
    args.dev_set = dev_set
    args.output_dir = output_dir
    args.gpu_id = args.local_rank
    args.config = finetune_config
    # Apply explicit keyword overrides on top of the merged config.
    if optim is not None:
        args.optim = optim
    if lr is not None:
        args.optim_conf["lr"] = lr
    if scheduler is not None:
        args.scheduler = scheduler
    if scheduler_conf is not None:
        args.scheduler_conf = scheduler_conf
    if specaug is not None:
        args.specaug = specaug
    if specaug_conf is not None:
        args.specaug_conf = specaug_conf
    if max_epoch is not None:
        args.max_epoch = max_epoch
    # ``batch_bins`` means batch bins for small datasets but maps onto the
    # dataset_conf batch size for large datasets.
    if batch_bins is not None:
        if args.dataset_type == "small":
            args.batch_bins = batch_bins
        elif args.dataset_type == "large":
            args.dataset_conf["batch_conf"]["batch_size"] = batch_bins
        else:
            raise ValueError(f"Not supported dataset_type={args.dataset_type}")
    # YAML "null"-ish strings coming from the config mean Python None.
    if args.normalize in ["null", "none", "None"]:
        args.normalize = None
    if args.patience in ["null", "none", "None"]:
        args.patience = None
    args.local_rank = local_rank
    args.distributed = distributed
    # Stash the finished namespace on the task class for the caller to use.
    ASRTask.finetune_args = args
    return ASRTask