#!/usr/bin/env python3
"""Entry-point script for language-model (LM) training via FunASR's LMTask.

Wraps ``LMTask``'s own CLI: adds a local ``--gpu_id`` option, pins the
process to that GPU through ``CUDA_VISIBLE_DEVICES``, derives the DDP
``distributed`` flag from ``--ngpu``, and rescales batch settings for the
"small" dataset type before handing control to ``LMTask.main``.
"""
import os

from funasr.tasks.lm import LMTask


# for LM Training
def parse_args():
    """Parse the LMTask command line plus the local ``--gpu_id`` option.

    Returns:
        argparse.Namespace: all LMTask training arguments, extended with
        ``gpu_id`` (int, default 0) selecting the local GPU to use.
    """
    parser = LMTask.get_parser()
    parser.add_argument(
        "--gpu_id",
        type=int,
        default=0,
        help="local gpu id.",
    )
    args = parser.parse_args()
    return args


def main(args=None, cmd=None):
    """Run LM training by delegating to ``LMTask.main``.

    Args:
        args: pre-parsed argument namespace, or ``None`` to let
            ``LMTask.main`` parse ``cmd``/``sys.argv`` itself.
        cmd: optional raw command-line list forwarded to ``LMTask.main``.
    """
    # for LM Training
    LMTask.main(args=args, cmd=cmd)


if __name__ == '__main__':
    args = parse_args()
    # setup local gpu_id: restrict CUDA to the single requested device
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
    # DDP settings: distributed training only when more than one GPU is requested
    if args.ngpu > 1:
        args.distributed = True
    else:
        args.distributed = False
    # NOTE(review): assert is stripped under `python -O`; kept as-is to
    # preserve the original failure mode (AssertionError) for callers.
    assert args.num_worker_count == 1
    # re-compute batch size: when dataset type is small, scale the per-process
    # batch settings by the GPU count so the effective global batch is constant
    if args.dataset_type == "small" and args.ngpu != 0:
        if args.batch_size is not None:
            args.batch_size = args.batch_size * args.ngpu
        if args.batch_bins is not None:
            args.batch_bins = args.batch_bins * args.ngpu
    main(args=args)