train.py

#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
import os
import sys
import torch
import hydra
import logging
import argparse
from io import BytesIO

import torch.distributed as dist
from collections.abc import Sequence
from omegaconf import DictConfig, OmegaConf
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from funasr.register import tables
from funasr.optimizers import optim_classes
from funasr.train_utils.trainer import Trainer
from funasr.schedulers import scheduler_classes
from funasr.train_utils.initialize import initialize
from funasr.download.download_from_hub import download_model
from funasr.models.lora.utils import mark_only_lora_as_trainable
from funasr.train_utils.set_all_random_seed import set_all_random_seed
from funasr.train_utils.load_pretrained_model import load_pretrained_model

# from funasr.tokenizer.build_tokenizer import build_tokenizer
# from funasr.tokenizer.token_id_converter import TokenIDConverter
# from funasr.tokenizer.funtoken import build_tokenizer


@hydra.main(config_name=None, version_base=None)
def main_hydra(kwargs: DictConfig):
    if kwargs.get("debug", False):
        import pdb; pdb.set_trace()

    assert "model" in kwargs
    if "model_conf" not in kwargs:
        logging.info("download models from model hub: {}".format(kwargs.get("hub", "ms")))
        kwargs = download_model(is_training=kwargs.get("is_training", True), **kwargs)

    main(**kwargs)


def main(**kwargs):
    print(kwargs)

    # set random seed
    set_all_random_seed(kwargs.get("seed", 0))
    torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
    torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
    torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)

    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    if local_rank == 0:
        tables.print()

    # Check if we are using DDP or FSDP
    use_ddp = "WORLD_SIZE" in os.environ and int(os.environ["WORLD_SIZE"]) > 1
    use_fsdp = kwargs.get("use_fsdp", None)
    if use_ddp or use_fsdp:
        dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method="env://")
        torch.cuda.set_device(local_rank)

    # save config.yaml
    if ((use_ddp or use_fsdp) and dist.get_rank() == 0) or (
        not (use_ddp or use_fsdp) and local_rank == 0
    ):
        os.makedirs(kwargs.get("output_dir", "./"), exist_ok=True)
        yaml_file = os.path.join(kwargs.get("output_dir", "./"), "config.yaml")
        OmegaConf.save(config=kwargs, f=yaml_file)
        logging.info("config.yaml is saved to: %s", yaml_file)
    # build tokenizer if tokenizer is not None
    tokenizer = kwargs.get("tokenizer", None)
    if tokenizer is not None:
        tokenizer_class = tables.tokenizer_classes.get(tokenizer)
        tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
        kwargs["tokenizer"] = tokenizer

    # build frontend if frontend is not None
    frontend = kwargs.get("frontend", None)
    if frontend is not None:
        frontend_class = tables.frontend_classes.get(frontend)
        frontend = frontend_class(**kwargs["frontend_conf"])
        kwargs["frontend"] = frontend
        kwargs["input_size"] = frontend.output_size()
    # build model
    model_class = tables.model_classes.get(kwargs["model"])
    vocab_size = len(tokenizer.token_list) if hasattr(tokenizer, "token_list") else None
    vocab_size = len(tokenizer.get_vocab()) if hasattr(tokenizer, "get_vocab") else vocab_size
    model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size)

    # init_param
    init_param = kwargs.get("init_param", None)
    if init_param is not None:
        if not isinstance(init_param, (list, tuple)):
            init_param = (init_param,)
        logging.info("init_param is not None: %s", init_param)
        for p in init_param:
            if os.path.exists(p):
                logging.info(f"Loading pretrained params from {p}")
                load_pretrained_model(
                    model=model,
                    path=p,
                    ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
                    oss_bucket=kwargs.get("oss_bucket", None),
                    scope_map=kwargs.get("scope_map", []),
                    excludes=kwargs.get("excludes", None),
                )
            else:
                logging.info(f"Checkpoint does not exist, init randomly: {p}")
    elif kwargs.get("init", None):
        initialize(model, kwargs.get("init", "kaiming_normal"))
    else:
        print("No initialize method")
    # freeze_param
    freeze_param = kwargs.get("freeze_param", None)
    if freeze_param is not None:
        if isinstance(freeze_param, str):
            # the config passes a literal such as "('encoder',)"; parse it into a tuple
            freeze_param = eval(freeze_param)
        if not isinstance(freeze_param, (list, tuple)):
            freeze_param = (freeze_param,)
        logging.info("freeze_param is not None: %s", freeze_param)
        for t in freeze_param:
            for k, p in model.named_parameters():
                if k.startswith(t + ".") or k == t:
                    logging.info(f"Setting {k}.requires_grad = False")
                    p.requires_grad = False
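    # Illustrative config values for the two blocks above (placeholders, not from the
    # original source): init_param is one checkpoint path or a list of paths, and
    # freeze_param names the parameter prefixes whose gradients are switched off, e.g.
    #   init_param: /path/to/pretrained/model.pt
    #   freeze_param: "('encoder',)"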
    if use_ddp:
        model = model.cuda(local_rank)
        model = DDP(
            model,
            device_ids=[local_rank],
            find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", False),
        )
    elif use_fsdp:
        model = FSDP(model).cuda(local_rank)
    else:
        model = model.to(device=kwargs.get("device", "cuda"))

    # optim
    optim = kwargs.get("optim", "adam")
    assert optim in optim_classes
    optim_class = optim_classes.get(optim)
    optim = optim_class(model.parameters(), **kwargs.get("optim_conf"))

    # scheduler
    scheduler = kwargs.get("scheduler", "warmuplr")
    assert scheduler in scheduler_classes
    scheduler_class = scheduler_classes.get(scheduler)
    scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))
    # dataset
    dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
    dataset_tr = dataset_class(
        kwargs.get("train_data_set_list"),
        frontend=frontend,
        tokenizer=tokenizer,
        is_training=True,
        **kwargs.get("dataset_conf"),
    )
    dataset_val = dataset_class(
        kwargs.get("valid_data_set_list"),
        frontend=frontend,
        tokenizer=tokenizer,
        is_training=False,
        **kwargs.get("dataset_conf"),
    )

    # dataloader
    batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler")
    batch_sampler_val = None
    if batch_sampler is not None:
        batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
        batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
        batch_sampler_val = batch_sampler_class(dataset_val, is_training=False, **kwargs.get("dataset_conf"))
    dataloader_tr = torch.utils.data.DataLoader(
        dataset_tr,
        collate_fn=dataset_tr.collator,
        batch_sampler=batch_sampler,
        num_workers=kwargs.get("dataset_conf").get("num_workers", 4),
        pin_memory=True,
    )
    dataloader_val = torch.utils.data.DataLoader(
        dataset_val,
        collate_fn=dataset_val.collator,
        batch_sampler=batch_sampler_val,
        num_workers=kwargs.get("dataset_conf").get("num_workers", 4),
        pin_memory=True,
    )
    trainer = Trainer(
        model=model,
        optim=optim,
        scheduler=scheduler,
        dataloader_train=dataloader_tr,
        dataloader_val=dataloader_val,
        local_rank=local_rank,
        use_ddp=use_ddp,
        use_fsdp=use_fsdp,
        output_dir=kwargs.get("output_dir", "./exp"),
        resume=kwargs.get("resume", True),
        **kwargs.get("train_conf"),
    )
    trainer.run()

    if use_ddp or use_fsdp:
        torch.distributed.destroy_process_group()


if __name__ == "__main__":
    main_hydra()
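

# A minimal launch sketch (illustrative, not part of the original file). The script reads
# LOCAL_RANK / WORLD_SIZE from the environment, so multi-GPU runs are typically started with
# torchrun, and config keys are passed as Hydra overrides; all names and paths below are
# placeholders:
#
#   torchrun --nproc_per_node=2 train.py \
#       ++model=<model_name_or_hub_id> \
#       ++train_data_set_list=<train_data_list> \
#       ++valid_data_set_list=<valid_data_list> \
#       ++output_dir=./exp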