@@ -50,7 +50,11 @@ def main(kwargs: DictConfig):
     use_fsdp = kwargs.get("use_fsdp", None)
     if use_ddp or use_fsdp:
         dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method='env://')
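+        # env:// rendezvous reads MASTER_ADDR / MASTER_PORT / RANK / WORLD_SIZE
+        # from the environment (typically exported by torchrun)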
-        device= torch.cuda.set_device(local_rank)
+        torch.cuda.set_device(local_rank)
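+        # NOTE: torch.cuda.set_device() returns None, so the removed line always
+        # bound "device" to None; use local_rank where a device index is needed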


     # build_tokenizer
@@ -72,9 +76,29 @@ def main(kwargs: DictConfig):
     # model_class = load_class_from_path(kwargs.get("model").split(":"))
     model_class = dynamic_import(kwargs.get("model"))
     model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list))
-    # model = model.to(device=kwargs.get("device", "cpu"))
-
-
+    frontend = model.frontend
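+    # keep a direct handle on the frontend before the model is wrapped below;
+    # a DDP-wrapped model would only expose it as model.module.frontend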
+    # init_param
+    init_param = kwargs.get("init_param", None)
+    if init_param is not None:
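+        # init_param is expected to be a Python-literal string from the config,
+        # e.g. "('/path/a.pt', '/path/b.pt')"; eval() turns it into a tuple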
+        init_param = eval(init_param)
+        if not isinstance(init_param, (list, tuple)):
+            init_param = (init_param,)
+        logging.info("init_param is not None: %s", init_param)
+        for p in init_param:
+            logging.info(f"Loading pretrained params from {p}")
+            load_pretrained_model(
+                model=model,
+                init_param=p,
+                ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
+                oss_bucket=kwargs.get("oss_bucket", None),
+            )
+    else:
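+        # no pretrained checkpoint given: fall back to random initialization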
+        initialize(model, kwargs.get("init", "kaiming_normal"))

     # import pdb;
     # pdb.set_trace()
@@ -97,6 +121,10 @@ def main(kwargs: DictConfig):
         model = DDP(model, device_ids=[local_rank])
     elif use_fsdp:
         model = FSDP(model).cuda(local_rank)
+    else:
+        model = model.to(device=kwargs.get("device", "cuda"))
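+        # single-process run: no distributed wrapper, just move the model to
+        # the configured device (default "cuda")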


     # optim
@@ -111,27 +139,9 @@ def main(kwargs: DictConfig):
     scheduler_class = scheduler_choices.get(scheduler)
     scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))

-    # init_param
-    init_param = kwargs.get("init_param", None)
-    if init_param is not None:
-        init_param = eval(init_param)
-        if isinstance(init_param, Sequence):
-            init_param = (init_param,)
-        logging.info("init_param is not None: ", freeze_param)
-        for p in init_param:
-            logging.info(f"Loading pretrained params from {p}")
-            load_pretrained_model(
-                model=model,
-                init_param=p,
-                ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
-                oss_bucket=kwargs.get("oss_bucket", None),
-            )
-    else:
-        initialize(model, kwargs.get("init", "kaiming_normal"))
-

     # dataset
-    dataset_tr = AudioDataset(kwargs.get("train_data_set_list"), frontend=model.frontend, tokenizer=tokenizer, **kwargs.get("dataset_conf"))
+    dataset_tr = AudioDataset(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, **kwargs.get("dataset_conf"))

     # dataloader
     batch_sampler = BatchSampler(dataset_tr, **kwargs.get("dataset_conf"), **kwargs.get("dataset_conf").get("batch_conf"))