# gen_modelscope_configuration.py
  1. import argparse
  2. import json
  3. import os
  4. if __name__ == '__main__':
  5. parser = argparse.ArgumentParser()
  6. parser.add_argument(
  7. "--task",
  8. type=str,
  9. default="auto-speech-recognition",
  10. help="task name",
  11. )
  12. parser.add_argument(
  13. "--type",
  14. type=str,
  15. default="generic-asr",
  16. )
  17. parser.add_argument(
  18. "--am_model_name",
  19. type=str,
  20. default="model.pb",
  21. help="model file name",
  22. )
  23. parser.add_argument(
  24. "--mode",
  25. type=str,
  26. default="paraformer",
  27. help="mode for decoding",
  28. )
  29. parser.add_argument(
  30. "--lang",
  31. type=str,
  32. default="zh-cn",
  33. help="language",
  34. )
  35. parser.add_argument(
  36. "--batch_size",
  37. type=int,
  38. default=1,
  39. help="batch size",
  40. )
  41. parser.add_argument(
  42. "--am_model_config",
  43. type=str,
  44. default="config.yaml",
  45. help="config file",
  46. )
  47. parser.add_argument(
  48. "--mvn_file",
  49. type=str,
  50. default="am.mvn",
  51. help="cmvn file",
  52. )
  53. parser.add_argument(
  54. "--model_name",
  55. type=str,
  56. help="model name",
  57. )
  58. parser.add_argument(
  59. "--pipeline_type",
  60. type=str,
  61. default="asr-inference",
  62. help="pipeline type",
  63. )
  64. parser.add_argument(
  65. "--vocab_size",
  66. type=int,
  67. help="vocab_size",
  68. )
  69. parser.add_argument(
  70. "--dataset",
  71. type=str,
  72. help="dataset name",
  73. )
  74. parser.add_argument(
  75. "--output_dir",
  76. type=str,
  77. help="output path",
  78. )
  79. parser.add_argument(
  80. "--nat",
  81. type=str,
  82. default="",
  83. help="nat",
  84. )
  85. parser.add_argument(
  86. "--tag",
  87. type=str,
  88. default="exp1",
  89. help="model name tag",
  90. )
  91. args = parser.parse_args()
  92. model = {
  93. "type": args.type,
  94. "am_model_name": args.am_model_name,
  95. "model_config": {
  96. "type": "pytorch",
  97. "code_base": "funasr",
  98. "mode": args.mode,
  99. "lang": args.lang,
  100. "batch_size": args.batch_size,
  101. "am_model_config": args.am_model_config,
  102. "mvn_file": args.mvn_file,
  103. "model": "speech_{}_asr{}-{}-16k-{}-vocab{}-pytorch-{}".format(args.model_name, args.nat, args.lang,
  104. args.dataset, args.vocab_size, args.tag),
  105. }
  106. }
  107. pipeline = {"type": args.pipeline_type}
  108. json_dict = {
  109. "framework": "pytorch",
  110. "task": args.task,
  111. "model": model,
  112. "pipeline": pipeline,
  113. }
  114. with open(os.path.join(args.output_dir, "configuration.json"), "w") as f:
  115. json.dump(json_dict, f, indent=4)