# build_trainer.py

import os

import yaml


def update_dct(fin_configs, root):
    # Recursively merge ``root`` into ``fin_configs``: leaf values from
    # ``root`` win, while dicts present on both sides are merged key by key.
    if root == {}:
        return {}
    for root_key, root_value in root.items():
        if not isinstance(root_value, dict):
            fin_configs[root_key] = root_value
        elif root_key in fin_configs:
            fin_configs[root_key] = update_dct(fin_configs[root_key], root_value)
        else:
            fin_configs[root_key] = root_value
    return fin_configs
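
# A quick illustration (not from the original source) of how update_dct
# merges a fine-tuning override into a base config:
#
#   base = {"optim_conf": {"lr": 1e-3, "eps": 1e-8}, "max_epoch": 50}
#   override = {"optim_conf": {"lr": 5e-4}}
#   update_dct(base, override)
#   # -> {"optim_conf": {"lr": 5e-4, "eps": 1e-8}, "max_epoch": 50}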


def parse_args(mode):
    # Select the task class that matches the fine-tuning mode; all three
    # paraformer variants share the same task class.
    if mode == "asr":
        from funasr.tasks.asr import ASRTask
    elif mode in ("paraformer", "paraformer_streaming", "paraformer_vad_punc"):
        from funasr.tasks.asr import ASRTaskParaformer as ASRTask
    elif mode == "uniasr":
        from funasr.tasks.asr import ASRTaskUniASR as ASRTask
    elif mode == "mfcca":
        from funasr.tasks.asr import ASRTaskMFCCA as ASRTask
    elif mode == "tp":
        from funasr.tasks.asr import ASRTaskAligner as ASRTask
    else:
        raise ValueError("Unknown mode: {}".format(mode))
    parser = ASRTask.get_parser()
    args = parser.parse_args()
    return args, ASRTask
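
# Note: parser.parse_args() reads sys.argv, so the selected task's
# command-line flags apply to the calling process. Typical use:
#
#   args, ASRTask = parse_args(mode="paraformer")
#   # args: argparse.Namespace with the task's defaults / CLI overrides
#   # ASRTask: funasr.tasks.asr.ASRTaskParaformer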


def build_trainer(modelscope_dict,
                  data_dir,
                  output_dir,
                  train_set="train",
                  dev_set="validation",
                  distributed=False,
                  dataset_type="small",
                  batch_bins=None,
                  max_epoch=None,
                  optim=None,
                  lr=None,
                  scheduler=None,
                  scheduler_conf=None,
                  specaug=None,
                  specaug_conf=None,
                  param_dict=None,
                  **kwargs):
    mode = modelscope_dict['mode']
    args, ASRTask = parse_args(mode=mode)

    # DDP related: a non-None local_rank means the process was started by a
    # distributed launcher, so it overrides the ``distributed`` argument.
    distributed = args.local_rank is not None
    args.local_rank = args.local_rank if args.local_rank is not None else 0
    local_rank = args.local_rank
    if "CUDA_VISIBLE_DEVICES" in os.environ:
        # Pin this process to the local_rank-th visible GPU: with
        # CUDA_VISIBLE_DEVICES=2,3 and local_rank=1, it uses physical GPU 3.
        gpu_list = os.environ['CUDA_VISIBLE_DEVICES'].split(",")
        os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_list[args.local_rank])
    else:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.local_rank)

    config = modelscope_dict['am_model_config']
    finetune_config = modelscope_dict['finetune_config']
    init_param = modelscope_dict['init_model']
    cmvn_file = modelscope_dict['cmvn_file']
    seg_dict_file = modelscope_dict['seg_dict']

    # Overwrite parameters: merge the fine-tuning config into the base
    # acoustic-model config (fine-tuning values take precedence), then copy
    # every matching key onto the parsed args.
    with open(config) as f:
        configs = yaml.safe_load(f)
    with open(finetune_config) as f:
        finetune_configs = yaml.safe_load(f)
    # Set data_types for the large-dataset pipeline.
    if dataset_type == "large":
        finetune_configs["dataset_conf"]["data_types"] = "sound,text"
    finetune_configs = update_dct(configs, finetune_configs)
    for key, value in finetune_configs.items():
        if hasattr(args, key):
            setattr(args, key, value)

    # Prepare data
    args.dataset_type = dataset_type
    if args.dataset_type == "small":
        args.train_data_path_and_name_and_type = [
            ["{}/{}/wav.scp".format(data_dir, train_set), "speech", "sound"],
            ["{}/{}/text".format(data_dir, train_set), "text", "text"],
        ]
        args.valid_data_path_and_name_and_type = [
            ["{}/{}/wav.scp".format(data_dir, dev_set), "speech", "sound"],
            ["{}/{}/text".format(data_dir, dev_set), "text", "text"],
        ]
    elif args.dataset_type == "large":
        args.train_data_file = None
        args.valid_data_file = None
    else:
        raise ValueError(f"Not supported dataset_type={args.dataset_type}")
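
    # Sketch of the Kaldi-style inputs referenced above; the exact line
    # format is an assumption from common speech recipes, not from this file:
    #   {data_dir}/{train_set}/wav.scp : "<utt_id> /path/to/audio.wav" per line
    #   {data_dir}/{train_set}/text    : "<utt_id> <transcript>" per line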
    args.init_param = [init_param]
    args.cmvn_file = cmvn_file
    args.seg_dict_file = seg_dict_file if os.path.exists(seg_dict_file) else None
    args.data_dir = data_dir
    args.train_set = train_set
    args.dev_set = dev_set
    args.output_dir = output_dir
    args.gpu_id = args.local_rank
    args.config = finetune_config

    # Apply caller-supplied overrides only when they are explicitly given.
    if optim is not None:
        args.optim = optim
    if lr is not None:
        args.optim_conf["lr"] = lr
    if scheduler is not None:
        args.scheduler = scheduler
    if scheduler_conf is not None:
        args.scheduler_conf = scheduler_conf
    if specaug is not None:
        args.specaug = specaug
    if specaug_conf is not None:
        args.specaug_conf = specaug_conf
    if max_epoch is not None:
        args.max_epoch = max_epoch
    if batch_bins is not None:
        # ``batch_bins`` lands on args.batch_bins for the small pipeline and
        # on dataset_conf["batch_conf"]["batch_size"] for the large one.
        if args.dataset_type == "small":
            args.batch_bins = batch_bins
        elif args.dataset_type == "large":
            args.dataset_conf["batch_conf"]["batch_size"] = batch_bins
        else:
            raise ValueError(f"Not supported dataset_type={args.dataset_type}")

    # Configs may leave these as the strings "null"/"none"/"None"; map them
    # to a real None so downstream checks behave.
    if args.normalize in ["null", "none", "None"]:
        args.normalize = None
    if args.patience in ["null", "none", "None"]:
        args.patience = None
    args.local_rank = local_rank
    args.distributed = distributed
    ASRTask.finetune_args = args
    return ASRTask
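

# Usage sketch (illustrative, not part of the original file): the keys of
# ``modelscope_dict`` mirror the lookups above; every path is hypothetical.
#
#   modelscope_dict = {
#       "mode": "paraformer",
#       "am_model_config": "./model/config.yaml",
#       "finetune_config": "./model/finetune_config.yaml",
#       "init_model": "./model/model.pb",
#       "cmvn_file": "./model/am.mvn",
#       "seg_dict": "./model/seg_dict",
#   }
#   ASRTask = build_trainer(modelscope_dict, data_dir="./data",
#                           output_dir="./exp", dataset_type="small")
#   # The merged configuration is exposed as ASRTask.finetune_args; training
#   # is then launched through the returned task class.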