| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566 |
- # -*- encoding: utf-8 -*-
- #!/usr/bin/env python3
- # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
- # MIT License (https://opensource.org/licenses/MIT)
- import os
- from funasr.tasks.asr import ASRTask
- # for ASR Training
def parse_args():
    """Build the FunASR ASR training parser and parse the command line.

    Extends the parser returned by ``ASRTask.get_parser()`` with two
    launcher-specific options:

    * ``--mode``: which ASR task variant ``main`` dispatches to.
    * ``--gpu_id``: local GPU index for this process.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = ASRTask.get_parser()
    parser.add_argument(
        "--mode",
        type=str,
        default="asr",
        # Restrict to the modes main() can dispatch so a typo fails at parse
        # time with a usage message rather than later inside main().
        choices=("asr", "paraformer", "uniasr", "rnnt"),
        help="ASR task variant to train (default: %(default)s).",
    )
    parser.add_argument(
        "--gpu_id",
        type=int,
        default=0,
        help="local gpu id.",
    )
    return parser.parse_args()
def main(args=None, cmd=None):
    """Dispatch training to the FunASR task class selected by ``args.mode``.

    Args:
        args: Parsed argument namespace; must provide ``mode``. The ``None``
            default exists only for signature compatibility; a namespace is
            required.
        cmd: Forwarded unchanged to the selected task class's ``main``.

    Raises:
        ValueError: If ``args.mode`` is not one of ``"asr"``,
            ``"paraformer"``, ``"uniasr"`` or ``"rnnt"``.
    """
    # An if/elif/else chain guarantees task_cls is always bound. Independent
    # ``if`` statements importing ``as ASRTask`` made that name function-local
    # and let an unknown mode fall through to an UnboundLocalError.
    mode = args.mode
    if mode == "asr":
        from funasr.tasks.asr import ASRTask as task_cls
    elif mode == "paraformer":
        from funasr.tasks.asr import ASRTaskParaformer as task_cls
    elif mode == "uniasr":
        from funasr.tasks.asr import ASRTaskUniASR as task_cls
    elif mode == "rnnt":
        from funasr.tasks.asr import ASRTransducerTask as task_cls
    else:
        raise ValueError(
            f"Unsupported mode {mode!r}; expected one of "
            "'asr', 'paraformer', 'uniasr', 'rnnt'"
        )
    task_cls.main(args=args, cmd=cmd)
def prepare_training_args(args):
    """Derive distributed and batch-size settings on a parsed namespace, in place.

    Sets ``args.distributed`` to ``True`` when more than one GPU is requested.
    When ``args.dataset_type == "small"``, multiplies ``batch_size`` and
    ``batch_bins`` (each only if not ``None``) by the GPU count.

    Args:
        args: Parsed training namespace. Reads ``ngpu``,
            ``num_worker_count``, ``dataset_type``, ``batch_size`` and
            ``batch_bins``; writes ``distributed`` and possibly
            ``batch_size`` / ``batch_bins``.

    Raises:
        ValueError: If a non-distributed run (``ngpu <= 1``) has
            ``num_worker_count != 1``.
    """
    args.distributed = args.ngpu > 1
    if not args.distributed and args.num_worker_count != 1:
        # Explicit check rather than ``assert`` so it still runs under -O.
        raise ValueError(
            "num_worker_count must be 1 when ngpu <= 1, "
            f"got {args.num_worker_count}"
        )
    if args.dataset_type == "small":
        # Scale by at least 1: on CPU runs (ngpu == 0) a plain multiply
        # would collapse the batch size to 0.
        scale = max(args.ngpu, 1)
        if args.batch_size is not None:
            args.batch_size *= scale
        if args.batch_bins is not None:
            args.batch_bins *= scale


if __name__ == "__main__":
    args = parse_args()
    # Expose only the requested local GPU to this process.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)
    prepare_training_args(args)
    main(args=args)
|