#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import os
from funasr.tasks.lm import LMTask
# for LM Training
  8. def parse_args():
  9. parser = LMTask.get_parser()
  10. parser.add_argument(
  11. "--gpu_id",
  12. type=int,
  13. default=0,
  14. help="local gpu id.",
  15. )
  16. args = parser.parse_args()
  17. return args
  18. def main(args=None, cmd=None):
  19. # for LM Training
  20. LMTask.main(args=args, cmd=cmd)
  21. if __name__ == '__main__':
  22. args = parse_args()
  23. # setup local gpu_id
  24. os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
  25. # DDP settings
  26. if args.ngpu > 1:
  27. args.distributed = True
  28. else:
  29. args.distributed = False
  30. assert args.num_worker_count == 1
  31. # re-compute batch size: when dataset type is small
  32. if args.dataset_type == "small" and args.ngpu != 0:
  33. if args.batch_size is not None:
  34. args.batch_size = args.batch_size * args.ngpu
  35. if args.batch_bins is not None:
  36. args.batch_bins = args.batch_bins * args.ngpu
  37. main(args=args)