build_trainer.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. import os
  2. import yaml
  3. def update_dct(fin_configs, root):
  4. if root == {}:
  5. return {}
  6. for root_key, root_value in root.items():
  7. if not isinstance(root[root_key], dict):
  8. fin_configs[root_key] = root[root_key]
  9. else:
  10. if root_key in fin_configs.keys():
  11. result = update_dct(fin_configs[root_key], root[root_key])
  12. fin_configs[root_key] = result
  13. else:
  14. fin_configs[root_key] = root[root_key]
  15. return fin_configs
  16. def parse_args(mode):
  17. if mode == "asr":
  18. from funasr.tasks.asr import ASRTask as ASRTask
  19. elif mode == "paraformer":
  20. from funasr.tasks.asr import ASRTaskParaformer as ASRTask
  21. elif mode == "paraformer_vad_punc":
  22. from funasr.tasks.asr import ASRTaskParaformer as ASRTask
  23. elif mode == "uniasr":
  24. from funasr.tasks.asr import ASRTaskUniASR as ASRTask
  25. else:
  26. raise ValueError("Unknown mode: {}".format(mode))
  27. parser = ASRTask.get_parser()
  28. args = parser.parse_args()
  29. return args, ASRTask
  30. def build_trainer(modelscope_dict,
  31. data_dir,
  32. output_dir,
  33. train_set="train",
  34. dev_set="validation",
  35. distributed=False,
  36. dataset_type="small",
  37. batch_bins=None,
  38. max_epoch=None,
  39. optim=None,
  40. lr=None,
  41. scheduler=None,
  42. scheduler_conf=None,
  43. specaug=None,
  44. specaug_conf=None,
  45. param_dict=None,
  46. **kwargs):
  47. mode = modelscope_dict['mode']
  48. args, ASRTask = parse_args(mode=mode)
  49. # ddp related
  50. if args.local_rank is not None:
  51. distributed = True
  52. else:
  53. distributed = False
  54. args.local_rank = args.local_rank if args.local_rank is not None else 0
  55. local_rank = args.local_rank
  56. if "CUDA_VISIBLE_DEVICES" in os.environ.keys():
  57. gpu_list = os.environ['CUDA_VISIBLE_DEVICES'].split(",")
  58. os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_list[args.local_rank])
  59. else:
  60. os.environ['CUDA_VISIBLE_DEVICES'] = str(args.local_rank)
  61. config = modelscope_dict['am_model_config']
  62. finetune_config = modelscope_dict['finetune_config']
  63. init_param = modelscope_dict['init_model']
  64. cmvn_file = modelscope_dict['cmvn_file']
  65. seg_dict_file = modelscope_dict['seg_dict']
  66. # overwrite parameters
  67. with open(config) as f:
  68. configs = yaml.safe_load(f)
  69. with open(finetune_config) as f:
  70. finetune_configs = yaml.safe_load(f)
  71. # set data_types
  72. if dataset_type == "large":
  73. finetune_configs["dataset_conf"]["data_types"] = "sound,text"
  74. finetune_configs = update_dct(configs, finetune_configs)
  75. for key, value in finetune_configs.items():
  76. if hasattr(args, key):
  77. setattr(args, key, value)
  78. # prepare data
  79. args.dataset_type = dataset_type
  80. if args.dataset_type == "small":
  81. args.train_data_path_and_name_and_type = [["{}/{}/wav.scp".format(data_dir, train_set), "speech", "sound"],
  82. ["{}/{}/text".format(data_dir, train_set), "text", "text"]]
  83. args.valid_data_path_and_name_and_type = [["{}/{}/wav.scp".format(data_dir, dev_set), "speech", "sound"],
  84. ["{}/{}/text".format(data_dir, dev_set), "text", "text"]]
  85. elif args.dataset_type == "large":
  86. args.train_data_file = None
  87. args.valid_data_file = None
  88. else:
  89. raise ValueError(f"Not supported dataset_type={args.dataset_type}")
  90. args.init_param = [init_param]
  91. args.cmvn_file = cmvn_file
  92. if os.path.exists(seg_dict_file):
  93. args.seg_dict_file = seg_dict_file
  94. else:
  95. args.seg_dict_file = None
  96. args.data_dir = data_dir
  97. args.train_set = train_set
  98. args.dev_set = dev_set
  99. args.output_dir = output_dir
  100. args.gpu_id = args.local_rank
  101. args.config = finetune_config
  102. if optim is not None:
  103. args.optim = optim
  104. if lr is not None:
  105. args.optim_conf["lr"] = lr
  106. if scheduler is not None:
  107. args.scheduler = scheduler
  108. if scheduler_conf is not None:
  109. args.scheduler_conf = scheduler_conf
  110. if specaug is not None:
  111. args.specaug = specaug
  112. if specaug_conf is not None:
  113. args.specaug_conf = specaug_conf
  114. if max_epoch is not None:
  115. args.max_epoch = max_epoch
  116. if batch_bins is not None:
  117. if args.dataset_type == "small":
  118. args.batch_bins = batch_bins
  119. elif args.dataset_type == "large":
  120. args.dataset_conf["batch_conf"]["batch_size"] = batch_bins
  121. else:
  122. raise ValueError(f"Not supported dataset_type={args.dataset_type}")
  123. if args.normalize in ["null", "none", "None"]:
  124. args.normalize = None
  125. if args.patience in ["null", "none", "None"]:
  126. args.patience = None
  127. args.local_rank = local_rank
  128. args.distributed = distributed
  129. ASRTask.finetune_args = args
  130. return ASRTask