punc_train_vadrealtime.py 866 B

1234567891011121314151617181920212223242526272829303132333435363738394041424344
  1. #!/usr/bin/env python3
  2. import os
  3. from funasr.tasks.punctuation import PunctuationTask
  4. def parse_args():
  5. parser = PunctuationTask.get_parser()
  6. parser.add_argument(
  7. "--gpu_id",
  8. type=int,
  9. default=0,
  10. help="local gpu id.",
  11. )
  12. parser.add_argument(
  13. "--punc_list",
  14. type=str,
  15. default=None,
  16. help="Punctuation list",
  17. )
  18. args = parser.parse_args()
  19. return args
  20. def main(args=None, cmd=None):
  21. """
  22. punc training.
  23. """
  24. PunctuationTask.main(args=args, cmd=cmd)
  25. if __name__ == "__main__":
  26. args = parse_args()
  27. # setup local gpu_id
  28. os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
  29. # DDP settings
  30. if args.ngpu > 1:
  31. args.distributed = True
  32. else:
  33. args.distributed = False
  34. assert args.num_worker_count == 1
  35. main(args=args)