# punc_train.py
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import os

from funasr.tasks.punctuation import PunctuationTask
  7. def parse_args():
  8. parser = PunctuationTask.get_parser()
  9. parser.add_argument(
  10. "--gpu_id",
  11. type=int,
  12. default=0,
  13. help="local gpu id.",
  14. )
  15. parser.add_argument(
  16. "--punc_list",
  17. type=str,
  18. default=None,
  19. help="Punctuation list",
  20. )
  21. args = parser.parse_args()
  22. return args
  23. def main(args=None, cmd=None):
  24. """
  25. punc training.
  26. """
  27. PunctuationTask.main(args=args, cmd=cmd)
  28. if __name__ == "__main__":
  29. args = parse_args()
  30. # setup local gpu_id
  31. os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
  32. # DDP settings
  33. if args.ngpu > 1:
  34. args.distributed = True
  35. else:
  36. args.distributed = False
  37. if args.dataset_type == "small":
  38. if args.batch_size is not None:
  39. args.batch_size = args.batch_size * args.ngpu * args.num_worker_count
  40. if args.batch_bins is not None:
  41. args.batch_bins = args.batch_bins * args.ngpu * args.num_worker_count
  42. main(args=args)