build_trainer.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. import os
  2. import yaml
  3. def update_dct(fin_configs, root):
  4. if root == {}:
  5. return {}
  6. for root_key, root_value in root.items():
  7. if not isinstance(root[root_key], dict):
  8. fin_configs[root_key] = root[root_key]
  9. else:
  10. if root_key in fin_configs.keys():
  11. result = update_dct(fin_configs[root_key], root[root_key])
  12. fin_configs[root_key] = result
  13. else:
  14. fin_configs[root_key] = root[root_key]
  15. return fin_configs
  16. def parse_args(mode):
  17. if mode == "asr":
  18. from funasr.tasks.asr import ASRTask as ASRTask
  19. elif mode == "paraformer":
  20. from funasr.tasks.asr import ASRTaskParaformer as ASRTask
  21. elif mode == "paraformer_vad_punc":
  22. from funasr.tasks.asr import ASRTaskParaformer as ASRTask
  23. elif mode == "uniasr":
  24. from funasr.tasks.asr import ASRTaskUniASR as ASRTask
  25. elif mode == "mfcca":
  26. from funasr.tasks.asr import ASRTaskMFCCA as ASRTask
  27. else:
  28. raise ValueError("Unknown mode: {}".format(mode))
  29. parser = ASRTask.get_parser()
  30. args = parser.parse_args()
  31. return args, ASRTask
def build_trainer(modelscope_dict,
                  data_dir,
                  output_dir,
                  train_set="train",
                  dev_set="validation",
                  distributed=False,
                  dataset_type="small",
                  batch_bins=None,
                  max_epoch=None,
                  optim=None,
                  lr=None,
                  scheduler=None,
                  scheduler_conf=None,
                  specaug=None,
                  specaug_conf=None,
                  param_dict=None,
                  **kwargs):
    """Configure an ASRTask for fine-tuning and return the task class.

    Loads the base model config and the fine-tune config (both YAML paths
    taken from ``modelscope_dict``), merges them via ``update_dct``, copies
    the merged values onto the parsed CLI ``args`` namespace, applies the
    explicit keyword overrides (optimizer, lr, scheduler, specaug, epochs,
    batch size), and finally stores the namespace on
    ``ASRTask.finetune_args``.

    Args:
        modelscope_dict: dict providing at least the keys 'mode',
            'am_model_config', 'finetune_config', 'init_model',
            'cmvn_file' and 'seg_dict' (the last five are file paths).
        data_dir: root directory of the prepared data sets.
        output_dir: directory for training outputs.
        train_set: sub-directory name of the training set.
        dev_set: sub-directory name of the validation set.
        distributed: NOTE(review): this parameter is ignored — it is
            unconditionally overwritten from ``args.local_rank`` below.
        dataset_type: "small" (scp/text file lists) or "large".
        batch_bins: optional batch-size override; interpreted per
            ``dataset_type`` (see below).
        max_epoch, optim, lr, scheduler, scheduler_conf, specaug,
        specaug_conf: optional overrides applied onto ``args`` when not None.
        param_dict: accepted but unused in this function.
        **kwargs: accepted but unused in this function.

    Returns:
        The resolved ASRTask class with ``finetune_args`` attached.

    Raises:
        ValueError: for an unknown mode (via ``parse_args``) or an
            unsupported ``dataset_type``.
    """
    mode = modelscope_dict['mode']
    # parse_args also resolves which ASRTask variant to use for this mode.
    args, ASRTask = parse_args(mode=mode)
    # ddp related: a launcher-provided local_rank implies distributed mode;
    # the ``distributed`` parameter passed in is overwritten either way.
    if args.local_rank is not None:
        distributed = True
    else:
        distributed = False
    args.local_rank = args.local_rank if args.local_rank is not None else 0
    # Remember the rank: args.local_rank is reused as a GPU index below and
    # restored at the end.
    local_rank = args.local_rank
    # Pin this process to a single GPU by narrowing CUDA_VISIBLE_DEVICES
    # down to the device that corresponds to the local rank.
    if "CUDA_VISIBLE_DEVICES" in os.environ.keys():
        gpu_list = os.environ['CUDA_VISIBLE_DEVICES'].split(",")
        os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_list[args.local_rank])
    else:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.local_rank)
    # File paths supplied by the modelscope model package.
    config = modelscope_dict['am_model_config']
    finetune_config = modelscope_dict['finetune_config']
    init_param = modelscope_dict['init_model']
    cmvn_file = modelscope_dict['cmvn_file']
    seg_dict_file = modelscope_dict['seg_dict']
    # overwrite parameters: load both YAML configs, then let the fine-tune
    # config take precedence over the base model config.
    with open(config) as f:
        configs = yaml.safe_load(f)
    with open(finetune_config) as f:
        finetune_configs = yaml.safe_load(f)
    # set data_types
    if dataset_type == "large":
        finetune_configs["dataset_conf"]["data_types"] = "sound,text"
    finetune_configs = update_dct(configs, finetune_configs)
    # Copy every merged config entry that matches an existing CLI argument
    # onto the args namespace (unknown keys are silently ignored).
    for key, value in finetune_configs.items():
        if hasattr(args, key):
            setattr(args, key, value)
    # prepare data: "small" datasets use explicit scp/text file lists;
    # "large" datasets resolve their data files elsewhere.
    args.dataset_type = dataset_type
    if args.dataset_type == "small":
        args.train_data_path_and_name_and_type = [["{}/{}/wav.scp".format(data_dir, train_set), "speech", "sound"],
                                                  ["{}/{}/text".format(data_dir, train_set), "text", "text"]]
        args.valid_data_path_and_name_and_type = [["{}/{}/wav.scp".format(data_dir, dev_set), "speech", "sound"],
                                                  ["{}/{}/text".format(data_dir, dev_set), "text", "text"]]
    elif args.dataset_type == "large":
        args.train_data_file = None
        args.valid_data_file = None
    else:
        raise ValueError(f"Not supported dataset_type={args.dataset_type}")
    # Initialize from the pretrained model checkpoint and its CMVN stats.
    args.init_param = [init_param]
    args.cmvn_file = cmvn_file
    # The segmentation dict is optional; only use it when the file exists.
    if os.path.exists(seg_dict_file):
        args.seg_dict_file = seg_dict_file
    else:
        args.seg_dict_file = None
    args.data_dir = data_dir
    args.train_set = train_set
    args.dev_set = dev_set
    args.output_dir = output_dir
    args.gpu_id = args.local_rank
    args.config = finetune_config
    # Explicit keyword overrides beat everything loaded from the configs.
    if optim is not None:
        args.optim = optim
    if lr is not None:
        args.optim_conf["lr"] = lr
    if scheduler is not None:
        args.scheduler = scheduler
    if scheduler_conf is not None:
        args.scheduler_conf = scheduler_conf
    if specaug is not None:
        args.specaug = specaug
    if specaug_conf is not None:
        args.specaug_conf = specaug_conf
    if max_epoch is not None:
        args.max_epoch = max_epoch
    # batch_bins means "bins per batch" for small datasets but is written
    # into dataset_conf.batch_conf.batch_size for large datasets.
    if batch_bins is not None:
        if args.dataset_type == "small":
            args.batch_bins = batch_bins
        elif args.dataset_type == "large":
            args.dataset_conf["batch_conf"]["batch_size"] = batch_bins
        else:
            raise ValueError(f"Not supported dataset_type={args.dataset_type}")
    # YAML null-ish string values are normalized to real None.
    if args.normalize in ["null", "none", "None"]:
        args.normalize = None
    if args.patience in ["null", "none", "None"]:
        args.patience = None
    # Restore the rank/distributed flags computed at the top.
    args.local_rank = local_rank
    args.distributed = distributed
    # The configured namespace is handed over via a class attribute.
    ASRTask.finetune_args = args
    return ASRTask