  1. """Initialize funasr package."""
  2. import os
  3. from pathlib import Path
  4. import torch
  5. import numpy as np
  6. dirname = os.path.dirname(__file__)
  7. version_file = os.path.join(dirname, "version.txt")
  8. with open(version_file, "r") as f:
  9. __version__ = f.read().strip()
  10. def prepare_model(
  11. model: str = None,
  12. # mode: str = None,
  13. vad_model: str = None,
  14. punc_model: str = None,
  15. model_hub: str = "ms",
  16. cache_dir: str = None,
  17. **kwargs,
  18. ):
  19. if not Path(model).exists():
  20. if model_hub == "ms" or model_hub == "modelscope":
  21. try:
  22. from modelscope.hub.snapshot_download import snapshot_download as download_tool
  23. model = name_maps_ms[model] if model is not None else None
  24. vad_model = name_maps_ms[vad_model] if vad_model is not None else None
  25. punc_model = name_maps_ms[punc_model] if punc_model is not None else None
  26. except:
  27. raise "You are exporting model from modelscope, please install modelscope and try it again. To install modelscope, you could:\n" \
  28. "\npip3 install -U modelscope\n" \
  29. "For the users in China, you could install with the command:\n" \
  30. "\npip3 install -U modelscope -i https://mirror.sjtu.edu.cn/pypi/web/simple"
  31. elif model_hub == "hf" or model_hub == "huggingface":
  32. download_tool = 0
  33. else:
  34. raise "model_hub must be on of ms or hf, but get {}".format(model_hub)
  35. try:
  36. model = download_tool(model, cache_dir=cache_dir, revision=kwargs.get("revision", None))
  37. print("model have been downloaded to: {}".format(model))
  38. except:
  39. raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(
  40. model)
  41. if vad_model is not None and not Path(vad_model).exists():
  42. vad_model = download_tool(vad_model, cache_dir=cache_dir)
  43. print("model have been downloaded to: {}".format(vad_model))
  44. if punc_model is not None and not Path(punc_model).exists():
  45. punc_model = download_tool(punc_model, cache_dir=cache_dir)
  46. print("model have been downloaded to: {}".format(punc_model))
  47. # asr
  48. kwargs.update({"cmvn_file": None if model is None else os.path.join(model, "am.mvn"),
  49. "asr_model_file": None if model is None else os.path.join(model, "model.pb"),
  50. "asr_train_config": None if model is None else os.path.join(model, "config.yaml"),
  51. })
  52. mode = kwargs.get("mode", None)
  53. if mode is None:
  54. import json
  55. json_file = os.path.join(model, 'configuration.json')
  56. with open(json_file, 'r') as f:
  57. config_data = json.load(f)
  58. if config_data['task'] == "punctuation":
  59. mode = config_data['model']['punc_model_config']['mode']
  60. else:
  61. mode = config_data['model']['model_config']['mode']
  62. if vad_model is not None and "vad" not in mode:
  63. mode = "paraformer_vad"
  64. kwargs["mode"] = mode
  65. # vad
  66. kwargs.update({"vad_cmvn_file": None if vad_model is None else os.path.join(vad_model, "vad.mvn"),
  67. "vad_model_file": None if vad_model is None else os.path.join(vad_model, "vad.pb"),
  68. "vad_infer_config": None if vad_model is None else os.path.join(vad_model, "vad.yaml"),
  69. })
  70. # punc
  71. kwargs.update({
  72. "punc_model_file": None if punc_model is None else os.path.join(punc_model, "punc.pb"),
  73. "punc_infer_config": None if punc_model is None else os.path.join(punc_model, "punc.yaml"),
  74. })
  75. return model, vad_model, punc_model, kwargs
# Mapping from short model aliases to their modelscope model ids; used by
# prepare_model() when the given name is not an existing local path.
# NOTE(review): "paraformer-en" and "paraformer-en-spk" map to the same
# modelscope id — confirm this is intentional.
name_maps_ms = {
    "paraformer-zh": "damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
    "paraformer-zh-spk": "damo/speech_paraformer-large-vad-punc-spk_asr_nat-zh-cn",
    "paraformer-en": "damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020",
    "paraformer-en-spk": "damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020",
    "paraformer-zh-streaming": "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online",
    "fsmn-vad": "damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
    "ct-punc": "damo/punc_ct-transformer_cn-en-common-vocab471067-large",
    "fa-zh": "damo/speech_timestamp_prediction-v1-16k-offline",
}
  86. def infer(task_name: str = "asr",
  87. model: str = None,
  88. # mode: str = None,
  89. vad_model: str = None,
  90. punc_model: str = None,
  91. model_hub: str = "ms",
  92. cache_dir: str = None,
  93. **kwargs,
  94. ):
  95. model, vad_model, punc_model, kwargs = prepare_model(model, vad_model, punc_model, model_hub, cache_dir, **kwargs)
  96. if task_name == "asr":
  97. from funasr.bin.asr_inference_launch import inference_launch
  98. inference_pipeline = inference_launch(**kwargs)
  99. elif task_name == "":
  100. pipeline = 1
  101. elif task_name == "":
  102. pipeline = 2
  103. elif task_name == "":
  104. pipeline = 2
  105. def _infer_fn(input, **kwargs):
  106. data_type = kwargs.get('data_type', 'sound')
  107. data_path_and_name_and_type = [input, 'speech', data_type]
  108. raw_inputs = None
  109. if isinstance(input, torch.Tensor):
  110. input = input.numpy()
  111. if isinstance(input, np.ndarray):
  112. data_path_and_name_and_type = None
  113. raw_inputs = input
  114. return inference_pipeline(data_path_and_name_and_type, raw_inputs=raw_inputs, **kwargs)
  115. return _infer_fn
# Import-only module: no CLI behavior is defined for direct execution.
if __name__ == '__main__':
    pass