lm_inference_launch.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. #!/usr/bin/env python3
  2. # Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
  3. # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
  4. import torch
  5. torch.set_num_threads(1)
  6. import argparse
  7. import logging
  8. import os
  9. import sys
  10. from typing import Union, Dict, Any
  11. from funasr.utils import config_argparse
  12. from funasr.utils.cli_utils import get_commandline_args
  13. from funasr.utils.types import str2bool
  14. from funasr.utils.types import str2triple_str
  15. from funasr.utils.types import str_or_none
  16. from funasr.utils.types import float_or_none
  17. def get_parser():
  18. parser = config_argparse.ArgumentParser(
  19. description="Calc perplexity",
  20. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  21. )
  22. parser.add_argument(
  23. "--log_level",
  24. type=lambda x: x.upper(),
  25. default="INFO",
  26. choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
  27. help="The verbose level of logging",
  28. )
  29. parser.add_argument("--output_dir", type=str, required=True)
  30. parser.add_argument("--gpuid_list", type=str, required=True)
  31. parser.add_argument(
  32. "--ngpu",
  33. type=int,
  34. default=0,
  35. help="The number of gpus. 0 indicates CPU mode",
  36. )
  37. parser.add_argument("--seed", type=int, default=0, help="Random seed")
  38. parser.add_argument("--njob", type=int, default=1, help="Random seed")
  39. parser.add_argument(
  40. "--dtype",
  41. default="float32",
  42. choices=["float16", "float32", "float64"],
  43. help="Data type",
  44. )
  45. parser.add_argument(
  46. "--num_workers",
  47. type=int,
  48. default=1,
  49. help="The number of workers used for DataLoader",
  50. )
  51. parser.add_argument(
  52. "--batch_size",
  53. type=int,
  54. default=1,
  55. help="The batch size for inference",
  56. )
  57. parser.add_argument(
  58. "--log_base",
  59. type=float_or_none,
  60. default=10,
  61. help="The base of logarithm for Perplexity. "
  62. "If None, napier's constant is used.",
  63. required=False
  64. )
  65. group = parser.add_argument_group("Input data related")
  66. group.add_argument(
  67. "--data_path_and_name_and_type",
  68. type=str2triple_str,
  69. action="append",
  70. required=False
  71. )
  72. group.add_argument(
  73. "--raw_inputs",
  74. type=str,
  75. required=False
  76. )
  77. group.add_argument("--key_file", type=str_or_none)
  78. group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
  79. group.add_argument("--split_with_space", type=str2bool, default=False)
  80. group.add_argument("--seg_dict_file", type=str_or_none)
  81. group = parser.add_argument_group("The model configuration related")
  82. group.add_argument("--train_config", type=str)
  83. group.add_argument("--model_file", type=str)
  84. group.add_argument("--mode", type=str, default="lm")
  85. return parser
  86. def inference_launch(mode, **kwargs):
  87. if mode == "transformer":
  88. from funasr.bin.lm_inference import inference_modelscope
  89. return inference_modelscope(**kwargs)
  90. else:
  91. logging.info("Unknown decoding mode: {}".format(mode))
  92. return None
  93. def main(cmd=None):
  94. print(get_commandline_args(), file=sys.stderr)
  95. parser = get_parser()
  96. args = parser.parse_args(cmd)
  97. kwargs = vars(args)
  98. kwargs.pop("config", None)
  99. # set logging messages
  100. logging.basicConfig(
  101. level=args.log_level,
  102. format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
  103. )
  104. logging.info("Decoding args: {}".format(kwargs))
  105. # gpu setting
  106. if args.ngpu > 0:
  107. jobid = int(args.output_dir.split(".")[-1])
  108. gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
  109. os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
  110. os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
  111. kwargs.pop("gpuid_list", None)
  112. kwargs.pop("njob", None)
  113. results = inference_launch(**kwargs)
  114. if __name__ == "__main__":
  115. main()