asr_train.py 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
  5. import os
  6. from funasr.tasks.asr import ASRTask
  7. # for ASR Training
  8. def parse_args():
  9. parser = ASRTask.get_parser()
  10. parser.add_argument(
  11. "--mode",
  12. type=str,
  13. default="asr",
  14. help=" ",
  15. )
  16. parser.add_argument(
  17. "--gpu_id",
  18. type=int,
  19. default=0,
  20. help="local gpu id.",
  21. )
  22. args = parser.parse_args()
  23. return args
  24. def main(args=None, cmd=None):
  25. # for ASR Training
  26. if args.mode == "asr":
  27. from funasr.tasks.asr import ASRTask
  28. if args.mode == "paraformer":
  29. from funasr.tasks.asr import ASRTaskParaformer as ASRTask
  30. if args.mode == "uniasr":
  31. from funasr.tasks.asr import ASRTaskUniASR as ASRTask
  32. if args.mode == "rnnt":
  33. from funasr.tasks.asr import ASRTransducerTask as ASRTask
  34. ASRTask.main(args=args, cmd=cmd)
  35. if __name__ == '__main__':
  36. args = parse_args()
  37. # setup local gpu_id
  38. if args.ngpu > 0:
  39. os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
  40. # DDP settings
  41. if args.ngpu > 1:
  42. args.distributed = True
  43. else:
  44. args.distributed = False
  45. assert args.num_worker_count == 1
  46. # re-compute batch size: when dataset type is small
  47. if args.dataset_type == "small":
  48. if args.batch_size is not None and args.ngpu > 0:
  49. args.batch_size = args.batch_size * args.ngpu
  50. if args.batch_bins is not None and args.ngpu > 0:
  51. args.batch_bins = args.batch_bins * args.ngpu
  52. main(args=args)