# train_cli.py
import argparse
import ast
import logging
import os
import sys
from collections.abc import Sequence
from io import BytesIO

import hydra
import torch
import torch.distributed as dist
from omegaconf import DictConfig, OmegaConf
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn.parallel import DistributedDataParallel as DDP

# from funasr.model_class_factory1 import model_choices
# from funasr.tokenizer.build_tokenizer import build_tokenizer
# from funasr.tokenizer.token_id_converter import TokenIDConverter
# from funasr.utils.load_fr_py import load_class_from_path
from funasr.cli.trainer import Trainer
from funasr.datasets.data_sampler import BatchSampler
from funasr.datasets.dataset_jsonl import AudioDataset
from funasr.modules.lora.utils import mark_only_lora_as_trainable
from funasr.optimizers import optim_choices
from funasr.schedulers import scheduler_choices
from funasr.tokenizer.funtoken import build_tokenizer
from funasr.torch_utils.initialize import initialize
from funasr.torch_utils.load_pretrained_model import load_pretrained_model
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils.dynamic_import import dynamic_import
  28. def preprocess_config(cfg: DictConfig):
  29. for key, value in cfg.items():
  30. if value == 'None':
  31. cfg[key] = None
  32. @hydra.main()
  33. def main(kwargs: DictConfig):
  34. # preprocess_config(kwargs)
  35. import pdb; pdb.set_trace()
  36. # set random seed
  37. set_all_random_seed(kwargs.get("seed", 0))
  38. torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
  39. torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
  40. torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)
  41. local_rank = int(os.environ.get('LOCAL_RANK', 0))
  42. # Check if we are using DDP or FSDP
  43. use_ddp = 'WORLD_SIZE' in os.environ
  44. use_fsdp = kwargs.get("use_fsdp", None)
  45. if use_ddp or use_fsdp:
  46. dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method='env://')
  47. torch.cuda.set_device(local_rank)
  48. # build_tokenizer
  49. tokenizer = build_tokenizer(
  50. token_type=kwargs.get("token_type", "char"),
  51. bpemodel=kwargs.get("bpemodel", None),
  52. delimiter=kwargs.get("delimiter", None),
  53. space_symbol=kwargs.get("space_symbol", "<space>"),
  54. non_linguistic_symbols=kwargs.get("non_linguistic_symbols", None),
  55. g2p_type=kwargs.get("g2p_type", None),
  56. token_list=kwargs.get("token_list", None),
  57. unk_symbol=kwargs.get("unk_symbol", "<unk>"),
  58. )
  59. # import pdb;
  60. # pdb.set_trace()
  61. # build model
  62. # model_class = model_choices.get_class(kwargs.get("model", "asr"))
  63. # model_class = load_class_from_path(kwargs.get("model").split(":"))
  64. model_class = dynamic_import(kwargs.get("model"))
  65. model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list))
  66. frontend = model.frontend
  67. # init_param
  68. init_param = kwargs.get("init_param", None)
  69. if init_param is not None:
  70. init_param = eval(init_param)
  71. if isinstance(init_param, Sequence):
  72. init_param = (init_param,)
  73. logging.info("init_param is not None: ", init_param)
  74. for p in init_param:
  75. logging.info(f"Loading pretrained params from {p}")
  76. load_pretrained_model(
  77. model=model,
  78. init_param=p,
  79. ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
  80. oss_bucket=kwargs.get("oss_bucket", None),
  81. )
  82. else:
  83. initialize(model, kwargs.get("init", "kaiming_normal"))
  84. # import pdb;
  85. # pdb.set_trace()
  86. # freeze_param
  87. freeze_param = kwargs.get("freeze_param", None)
  88. if freeze_param is not None:
  89. freeze_param = eval(freeze_param)
  90. if isinstance(freeze_param, Sequence):
  91. freeze_param = (freeze_param,)
  92. logging.info("freeze_param is not None: ", freeze_param)
  93. for t in freeze_param:
  94. for k, p in model.named_parameters():
  95. if k.startswith(t + ".") or k == t:
  96. logging.info(f"Setting {k}.requires_grad = False")
  97. p.requires_grad = False
  98. if use_ddp:
  99. model = model.cuda(local_rank)
  100. model = DDP(model, device_ids=[local_rank])
  101. elif use_fsdp:
  102. model = FSDP(model).cuda(local_rank)
  103. else:
  104. model = model.to(device=kwargs.get("device", "cuda"))
  105. # optim
  106. optim = kwargs.get("optim", "adam")
  107. assert optim in optim_choices
  108. optim_class = optim_choices.get(optim)
  109. optim = optim_class(model.parameters(), **kwargs.get("optim_conf"))
  110. # scheduler
  111. scheduler = kwargs.get("scheduler", "warmuplr")
  112. assert scheduler in scheduler_choices
  113. scheduler_class = scheduler_choices.get(scheduler)
  114. scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))
  115. # dataset
  116. dataset_tr = AudioDataset(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, **kwargs.get("dataset_conf"))
  117. # dataloader
  118. batch_sampler = BatchSampler(dataset_tr, **kwargs.get("dataset_conf"), **kwargs.get("dataset_conf").get("batch_conf"))
  119. dataloader_tr = torch.utils.data.DataLoader(dataset_tr,
  120. collate_fn=dataset_tr.collator,
  121. batch_sampler=batch_sampler,
  122. num_workers=kwargs.get("num_workers", 0),
  123. pin_memory=True)
  124. trainer = Trainer(
  125. model=model,
  126. optim=optim,
  127. scheduler=scheduler,
  128. dataloader_train=dataloader_tr,
  129. dataloader_val=None,
  130. local_rank=local_rank,
  131. use_ddp=use_ddp,
  132. use_fsdp=use_fsdp,
  133. **kwargs.get("train_conf"),
  134. )
  135. trainer.run()
  136. if use_ddp or use_fsdp:
  137. torch.distributed.destroy_process_group()
  138. def train(epoch, model, op):
  139. pass
  140. def val():
  141. pass
# Script entry point: Hydra parses the CLI overrides and calls the decorated main().
if __name__ == "__main__":
    main()