modelscope_infer.py 2.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. #!/usr/bin/env python3
  2. import argparse
  3. import logging
  4. import os
  5. from modelscope.pipelines import pipeline
  6. from modelscope.utils.constant import Tasks
  7. if __name__ == '__main__':
  8. parser = argparse.ArgumentParser(
  9. description="decoding configs",
  10. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  11. )
  12. parser.add_argument("--model_name",
  13. type=str,
  14. default="speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
  15. help="model name in modelscope")
  16. parser.add_argument("--local_model_path",
  17. type=str,
  18. default=None,
  19. help="local model path, usually for fine-tuning")
  20. parser.add_argument("--wav_list",
  21. type=str,
  22. help="input wav list")
  23. parser.add_argument("--output_file",
  24. type=str,
  25. help="saving decoding results")
  26. parser.add_argument(
  27. "--njob",
  28. type=int,
  29. default=1,
  30. help="The number of jobs for each gpu",
  31. )
  32. parser.add_argument(
  33. "--gpuid_list",
  34. type=str,
  35. default="",
  36. help="The visible gpus",
  37. )
  38. parser.add_argument(
  39. "--ngpu",
  40. type=int,
  41. default=0,
  42. help="The number of gpus. 0 indicates CPU mode",
  43. )
  44. args = parser.parse_args()
  45. # set logging messages
  46. logging.basicConfig(
  47. level=logging.INFO,
  48. format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
  49. )
  50. logging.info("Decoding args: {}".format(args))
  51. # gpu setting
  52. if args.ngpu > 0:
  53. jobid = int(args.output_file.split(".")[-1])
  54. gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
  55. os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
  56. os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
  57. if args.local_model_path is None:
  58. inference_pipeline = pipeline(
  59. task=Tasks.auto_speech_recognition,
  60. model="damo/{}".format(args.model_name))
  61. else:
  62. inference_pipeline = pipeline(
  63. task=Tasks.auto_speech_recognition,
  64. model=args.local_model_path)
  65. with open(args.wav_list, 'r') as f_wav:
  66. wav_lines = f_wav.readlines()
  67. with open(args.output_file, "w") as f_out:
  68. for line in wav_lines:
  69. wav_id, wav_path = line.strip().split()
  70. logging.info("decoding, utt_id: ['{}']".format(wav_id))
  71. rec_result = inference_pipeline(audio_in=wav_path)
  72. text = rec_result["text"]
  73. f_out.write(wav_id + " " + text + "\n")
  74. logging.info("best hypo: {} \n".format(text))