# build_trainer.py

import os

import yaml


def update_dct(fin_configs, root):
    if root == {}:
        return {}
    for root_key, root_value in root.items():
        if not isinstance(root_value, dict):
            fin_configs[root_key] = root_value
        else:
            if root_key in fin_configs:
                fin_configs[root_key] = update_dct(fin_configs[root_key], root_value)
            else:
                fin_configs[root_key] = root_value
    return fin_configs
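
# Illustrative merge behaviour of update_dct (example values are hypothetical,
# not taken from any shipped config): nested dicts are merged key by key, the
# override wins on conflicts, and fin_configs is mutated in place.
#   base = {"optim_conf": {"lr": 0.0005, "weight_decay": 0.0}, "max_epoch": 50}
#   override = {"optim_conf": {"lr": 0.0001}}
#   update_dct(base, override)
#   # -> {"optim_conf": {"lr": 0.0001, "weight_decay": 0.0}, "max_epoch": 50}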


def parse_args(mode):
    if mode == "asr":
        from funasr.tasks.asr import ASRTask
    elif mode in ("paraformer", "paraformer_streaming", "paraformer_vad_punc"):
        from funasr.tasks.asr import ASRTaskParaformer as ASRTask
    elif mode == "uniasr":
        from funasr.tasks.asr import ASRTaskUniASR as ASRTask
    elif mode == "mfcca":
        from funasr.tasks.asr import ASRTaskMFCCA as ASRTask
    elif mode == "tp":
        from funasr.tasks.asr import ASRTaskAligner as ASRTask
    else:
        raise ValueError("Unknown mode: {}".format(mode))
    parser = ASRTask.get_parser()
    args = parser.parse_args()
    return args, ASRTask
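
# parse_args builds the selected task's argparse parser and reads its arguments
# from the command line; a sketch of the mode -> task mapping (illustrative calls):
#   args, ASRTask = parse_args(mode="paraformer")   # -> ASRTaskParaformer
#   args, ASRTask = parse_args(mode="uniasr")       # -> ASRTaskUniASR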


def build_trainer(modelscope_dict,
                  data_dir,
                  output_dir,
                  train_set="train",
                  dev_set="validation",
                  distributed=False,
                  dataset_type="small",
                  batch_bins=None,
                  max_epoch=None,
                  optim=None,
                  lr=None,
                  scheduler=None,
                  scheduler_conf=None,
                  specaug=None,
                  specaug_conf=None,
                  mate_params=None,
                  **kwargs):
    mode = modelscope_dict['mode']
    args, ASRTask = parse_args(mode=mode)
    # ddp related
    if args.local_rank is not None:
        distributed = True
    else:
        distributed = False
    args.local_rank = args.local_rank if args.local_rank is not None else 0
    local_rank = args.local_rank
    if "CUDA_VISIBLE_DEVICES" in os.environ:
        gpu_list = os.environ['CUDA_VISIBLE_DEVICES'].split(",")
        os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_list[args.local_rank])
    else:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.local_rank)
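    # Note: --local_rank is the flag that torch.distributed.launch-style launchers
    # pass to each worker; when it is present the run is treated as distributed and
    # CUDA_VISIBLE_DEVICES is narrowed to that worker's own GPU.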
    config = modelscope_dict['am_model_config']
    finetune_config = modelscope_dict['finetune_config']
    init_param = modelscope_dict['init_model']
    cmvn_file = modelscope_dict['cmvn_file']
    seg_dict_file = modelscope_dict['seg_dict']
    # overwrite parameters
    with open(config) as f:
        configs = yaml.safe_load(f)
    with open(finetune_config) as f:
        finetune_configs = yaml.safe_load(f)
    # set data_types
    if dataset_type == "large":
        # finetune_configs["dataset_conf"]["data_types"] = "sound,text"
        if 'data_types' not in finetune_configs['dataset_conf']:
            finetune_configs["dataset_conf"]["data_types"] = "sound,text"
    finetune_configs = update_dct(configs, finetune_configs)
    for key, value in finetune_configs.items():
        if hasattr(args, key):
            setattr(args, key, value)
    if mate_params is not None:
        for key, value in mate_params.items():
            if hasattr(args, key):
                setattr(args, key, value)
    if mate_params is not None and "lora_params" in mate_params:
        lora_params = mate_params['lora_params']
        configs['encoder_conf'].update(lora_params)
        configs['decoder_conf'].update(lora_params)
    # prepare data
    args.dataset_type = dataset_type
    if args.dataset_type == "small":
        args.train_data_path_and_name_and_type = [
            ["{}/{}/wav.scp".format(data_dir, train_set), "speech", "sound"],
            ["{}/{}/text".format(data_dir, train_set), "text", "text"],
        ]
        args.valid_data_path_and_name_and_type = [
            ["{}/{}/wav.scp".format(data_dir, dev_set), "speech", "sound"],
            ["{}/{}/text".format(data_dir, dev_set), "text", "text"],
        ]
    elif args.dataset_type == "large":
        args.train_data_file = None
        args.valid_data_file = None
    else:
        raise ValueError(f"Not supported dataset_type={args.dataset_type}")
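    # Expected on-disk layout for dataset_type="small" (illustrative, Kaldi-style
    # scp/text files; only the file names above come from this code, the contents
    # shown here are assumptions):
    #   <data_dir>/<train_set>/wav.scp   e.g. "utt_001 /path/to/utt_001.wav"
    #   <data_dir>/<train_set>/text      e.g. "utt_001 the transcript"
    #   <data_dir>/<dev_set>/wav.scp and <data_dir>/<dev_set>/text, same format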
    args.init_param = [init_param]
    if mate_params is not None and "init_param" in mate_params:
        if len(mate_params["init_param"]) != 0:
            args.init_param = mate_params["init_param"]
    args.cmvn_file = cmvn_file
    if os.path.exists(seg_dict_file):
        args.seg_dict_file = seg_dict_file
    else:
        args.seg_dict_file = None
    args.data_dir = data_dir
    args.train_set = train_set
    args.dev_set = dev_set
    args.output_dir = output_dir
    args.gpu_id = args.local_rank
    args.config = finetune_config
    if optim is not None:
        args.optim = optim
    if lr is not None:
        args.optim_conf["lr"] = lr
    if scheduler is not None:
        args.scheduler = scheduler
    if scheduler_conf is not None:
        args.scheduler_conf = scheduler_conf
    if specaug is not None:
        args.specaug = specaug
    if specaug_conf is not None:
        args.specaug_conf = specaug_conf
    if max_epoch is not None:
        args.max_epoch = max_epoch
    if batch_bins is not None:
        if args.dataset_type == "small":
            args.batch_bins = batch_bins
        elif args.dataset_type == "large":
            args.dataset_conf["batch_conf"]["batch_size"] = batch_bins
        else:
            raise ValueError(f"Not supported dataset_type={args.dataset_type}")
    if args.normalize in ["null", "none", "None"]:
        args.normalize = None
    if args.patience in ["null", "none", "None"]:
        args.patience = None
    args.local_rank = local_rank
    args.distributed = distributed
    ASRTask.finetune_args = args
    return ASRTask
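

# Minimal usage sketch (hypothetical values; the dict keys match what this module
# reads from modelscope_dict, every path and number below is an assumption made
# purely for illustration):
#   modelscope_dict = {
#       "mode": "paraformer",
#       "am_model_config": "/path/to/config.yaml",
#       "finetune_config": "/path/to/finetune_config.yaml",
#       "init_model": "/path/to/model.pb",
#       "cmvn_file": "/path/to/am.mvn",
#       "seg_dict": "/path/to/seg_dict",
#   }
#   ASRTask = build_trainer(modelscope_dict,
#                           data_dir="/path/to/data",
#                           output_dir="/path/to/exp",
#                           dataset_type="small",
#                           batch_bins=2000,
#                           max_epoch=20)
#   # ASRTask.finetune_args now holds the merged arguments used for fine-tuning.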