|
@@ -0,0 +1,145 @@
|
|
|
|
|
+import os
|
|
|
|
|
+
|
|
|
|
|
+import yaml
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def update_dct(fin_configs, root):
|
|
|
|
|
+ if root == {}:
|
|
|
|
|
+ return {}
|
|
|
|
|
+ for root_key, root_value in root.items():
|
|
|
|
|
+ if not isinstance(root[root_key], dict):
|
|
|
|
|
+ fin_configs[root_key] = root[root_key]
|
|
|
|
|
+ else:
|
|
|
|
|
+ if root_key in fin_configs.keys():
|
|
|
|
|
+ result = update_dct(fin_configs[root_key], root[root_key])
|
|
|
|
|
+ fin_configs[root_key] = result
|
|
|
|
|
+ else:
|
|
|
|
|
+ fin_configs[root_key] = root[root_key]
|
|
|
|
|
+ return fin_configs
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def parse_args(mode):
|
|
|
|
|
+ if mode == "asr":
|
|
|
|
|
+ from funasr.tasks.asr import ASRTask as ASRTask
|
|
|
|
|
+ elif mode == "paraformer":
|
|
|
|
|
+ from funasr.tasks.asr import ASRTaskParaformer as ASRTask
|
|
|
|
|
+ elif mode == "paraformer_vad_punc":
|
|
|
|
|
+ from funasr.tasks.asr import ASRTaskParaformer as ASRTask
|
|
|
|
|
+ elif mode == "uniasr":
|
|
|
|
|
+ from funasr.tasks.asr import ASRTaskUniASR as ASRTask
|
|
|
|
|
+ elif mode == "mfcca":
|
|
|
|
|
+ from funasr.tasks.asr import ASRTaskMFCCA as ASRTask
|
|
|
|
|
+ elif mode == "tp":
|
|
|
|
|
+ from funasr.tasks.asr import ASRTaskAligner as ASRTask
|
|
|
|
|
+ else:
|
|
|
|
|
+ raise ValueError("Unknown mode: {}".format(mode))
|
|
|
|
|
+ parser = ASRTask.get_parser()
|
|
|
|
|
+ args = parser.parse_args()
|
|
|
|
|
+ return args, ASRTask
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def build_trainer(modelscope_dict,
|
|
|
|
|
+ data_dir,
|
|
|
|
|
+ output_dir,
|
|
|
|
|
+ train_set="train",
|
|
|
|
|
+ dev_set="validation",
|
|
|
|
|
+ distributed=False,
|
|
|
|
|
+ dataset_type="small",
|
|
|
|
|
+ batch_bins=None,
|
|
|
|
|
+ max_epoch=None,
|
|
|
|
|
+ optim=None,
|
|
|
|
|
+ lr=None,
|
|
|
|
|
+ scheduler=None,
|
|
|
|
|
+ scheduler_conf=None,
|
|
|
|
|
+ specaug=None,
|
|
|
|
|
+ specaug_conf=None,
|
|
|
|
|
+ param_dict=None,
|
|
|
|
|
+ **kwargs):
|
|
|
|
|
+ mode = modelscope_dict['mode']
|
|
|
|
|
+ args, ASRTask = parse_args(mode=mode)
|
|
|
|
|
+ # ddp related
|
|
|
|
|
+ if args.local_rank is not None:
|
|
|
|
|
+ distributed = True
|
|
|
|
|
+ else:
|
|
|
|
|
+ distributed = False
|
|
|
|
|
+ args.local_rank = args.local_rank if args.local_rank is not None else 0
|
|
|
|
|
+ local_rank = args.local_rank
|
|
|
|
|
+ if "CUDA_VISIBLE_DEVICES" in os.environ.keys():
|
|
|
|
|
+ gpu_list = os.environ['CUDA_VISIBLE_DEVICES'].split(",")
|
|
|
|
|
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_list[args.local_rank])
|
|
|
|
|
+ else:
|
|
|
|
|
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(args.local_rank)
|
|
|
|
|
+
|
|
|
|
|
+ config = modelscope_dict['am_model_config']
|
|
|
|
|
+ finetune_config = modelscope_dict['finetune_config']
|
|
|
|
|
+ init_param = modelscope_dict['init_model']
|
|
|
|
|
+ cmvn_file = modelscope_dict['cmvn_file']
|
|
|
|
|
+ seg_dict_file = modelscope_dict['seg_dict']
|
|
|
|
|
+
|
|
|
|
|
+ # overwrite parameters
|
|
|
|
|
+ with open(config) as f:
|
|
|
|
|
+ configs = yaml.safe_load(f)
|
|
|
|
|
+ with open(finetune_config) as f:
|
|
|
|
|
+ finetune_configs = yaml.safe_load(f)
|
|
|
|
|
+ # set data_types
|
|
|
|
|
+ if dataset_type == "large":
|
|
|
|
|
+ finetune_configs["dataset_conf"]["data_types"] = "sound,text"
|
|
|
|
|
+ finetune_configs = update_dct(configs, finetune_configs)
|
|
|
|
|
+ for key, value in finetune_configs.items():
|
|
|
|
|
+ if hasattr(args, key):
|
|
|
|
|
+ setattr(args, key, value)
|
|
|
|
|
+
|
|
|
|
|
+ # prepare data
|
|
|
|
|
+ args.dataset_type = dataset_type
|
|
|
|
|
+ if args.dataset_type == "small":
|
|
|
|
|
+ args.train_data_path_and_name_and_type = [["{}/{}/wav.scp".format(data_dir, train_set), "speech", "sound"],
|
|
|
|
|
+ ["{}/{}/text".format(data_dir, train_set), "text", "text"]]
|
|
|
|
|
+ args.valid_data_path_and_name_and_type = [["{}/{}/wav.scp".format(data_dir, dev_set), "speech", "sound"],
|
|
|
|
|
+ ["{}/{}/text".format(data_dir, dev_set), "text", "text"]]
|
|
|
|
|
+ elif args.dataset_type == "large":
|
|
|
|
|
+ args.train_data_file = None
|
|
|
|
|
+ args.valid_data_file = None
|
|
|
|
|
+ else:
|
|
|
|
|
+ raise ValueError(f"Not supported dataset_type={args.dataset_type}")
|
|
|
|
|
+ args.init_param = [init_param]
|
|
|
|
|
+ args.cmvn_file = cmvn_file
|
|
|
|
|
+ if os.path.exists(seg_dict_file):
|
|
|
|
|
+ args.seg_dict_file = seg_dict_file
|
|
|
|
|
+ else:
|
|
|
|
|
+ args.seg_dict_file = None
|
|
|
|
|
+ args.data_dir = data_dir
|
|
|
|
|
+ args.train_set = train_set
|
|
|
|
|
+ args.dev_set = dev_set
|
|
|
|
|
+ args.output_dir = output_dir
|
|
|
|
|
+ args.gpu_id = args.local_rank
|
|
|
|
|
+ args.config = finetune_config
|
|
|
|
|
+ if optim is not None:
|
|
|
|
|
+ args.optim = optim
|
|
|
|
|
+ if lr is not None:
|
|
|
|
|
+ args.optim_conf["lr"] = lr
|
|
|
|
|
+ if scheduler is not None:
|
|
|
|
|
+ args.scheduler = scheduler
|
|
|
|
|
+ if scheduler_conf is not None:
|
|
|
|
|
+ args.scheduler_conf = scheduler_conf
|
|
|
|
|
+ if specaug is not None:
|
|
|
|
|
+ args.specaug = specaug
|
|
|
|
|
+ if specaug_conf is not None:
|
|
|
|
|
+ args.specaug_conf = specaug_conf
|
|
|
|
|
+ if max_epoch is not None:
|
|
|
|
|
+ args.max_epoch = max_epoch
|
|
|
|
|
+ if batch_bins is not None:
|
|
|
|
|
+ if args.dataset_type == "small":
|
|
|
|
|
+ args.batch_bins = batch_bins
|
|
|
|
|
+ elif args.dataset_type == "large":
|
|
|
|
|
+ args.dataset_conf["batch_conf"]["batch_size"] = batch_bins
|
|
|
|
|
+ else:
|
|
|
|
|
+ raise ValueError(f"Not supported dataset_type={args.dataset_type}")
|
|
|
|
|
+ if args.normalize in ["null", "none", "None"]:
|
|
|
|
|
+ args.normalize = None
|
|
|
|
|
+ if args.patience in ["null", "none", "None"]:
|
|
|
|
|
+ args.patience = None
|
|
|
|
|
+ args.local_rank = local_rank
|
|
|
|
|
+ args.distributed = distributed
|
|
|
|
|
+ ASRTask.finetune_args = args
|
|
|
|
|
+
|
|
|
|
|
+ return ASRTask
|