#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import os

from funasr.tasks.sa_asr import ASRTask

# for ASR Training
  8. def parse_args():
  9. parser = ASRTask.get_parser()
  10. parser.add_argument(
  11. "--gpu_id",
  12. type=int,
  13. default=0,
  14. help="local gpu id.",
  15. )
  16. args = parser.parse_args()
  17. return args
  18. def main(args=None, cmd=None):
  19. # for ASR Training
  20. ASRTask.main(args=args, cmd=cmd)
  21. if __name__ == '__main__':
  22. args = parse_args()
  23. # setup local gpu_id
  24. if args.ngpu > 0:
  25. os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
  26. # DDP settings
  27. if args.ngpu > 1:
  28. args.distributed = True
  29. else:
  30. args.distributed = False
  31. assert args.num_worker_count == 1
  32. # re-compute batch size: when dataset type is small
  33. if args.dataset_type == "small":
  34. if args.batch_size is not None and args.ngpu > 0:
  35. args.batch_size = args.batch_size * args.ngpu
  36. if args.batch_bins is not None and args.ngpu > 0:
  37. args.batch_bins = args.batch_bins * args.ngpu
  38. main(args=args)