data2vec_train.py 1012 B

123456789101112131415161718192021222324252627282930313233343536373839404142434445
  1. #!/usr/bin/env python3
  2. import os
  3. from funasr.tasks.data2vec import Data2VecTask
  4. def parse_args():
  5. parser = Data2VecTask.get_parser()
  6. parser.add_argument(
  7. "--gpu_id",
  8. type=int,
  9. default=0,
  10. help="local gpu id.",
  11. )
  12. args = parser.parse_args()
  13. return args
  14. def main(args=None, cmd=None):
  15. # for data2vec Training
  16. Data2VecTask.main(args=args, cmd=cmd)
  17. if __name__ == '__main__':
  18. args = parse_args()
  19. # setup local gpu_id
  20. os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
  21. # DDP settings
  22. if args.ngpu > 1:
  23. args.distributed = True
  24. else:
  25. args.distributed = False
  26. assert args.num_worker_count == 1
  27. # re-compute batch size: when dataset type is small
  28. if args.dataset_type == "small":
  29. if args.batch_size is not None:
  30. args.batch_size = args.batch_size * args.ngpu
  31. if args.batch_bins is not None:
  32. args.batch_bins = args.batch_bins * args.ngpu
  33. main(args=args)