# grpc_server.py — gRPC streaming ASR server (Paraformer backend)
from concurrent import futures
import json
import time

import grpc

import paraformer_pb2_grpc
from paraformer_pb2 import Response
  7. class ASRServicer(paraformer_pb2_grpc.ASRServicer):
  8. def __init__(self, user_allowed, model, sample_rate, backend, onnx_dir, vad_model='', punc_model=''):
  9. print("ASRServicer init")
  10. self.backend = backend
  11. self.init_flag = 0
  12. self.client_buffers = {}
  13. self.client_transcription = {}
  14. self.auth_user = user_allowed.split("|")
  15. if self.backend == "pipeline":
  16. try:
  17. from modelscope.pipelines import pipeline
  18. from modelscope.utils.constant import Tasks
  19. except ImportError:
  20. raise ImportError(f"Please install modelscope")
  21. self.inference_16k_pipeline = pipeline(task=Tasks.auto_speech_recognition, model=model, vad_model=vad_model, punc_model=punc_model)
  22. elif self.backend == "onnxruntime":
  23. try:
  24. from funasr_onnx import Paraformer
  25. except ImportError:
  26. raise ImportError(f"Please install onnxruntime environment")
  27. self.inference_16k_pipeline = Paraformer(model_dir=onnx_dir)
  28. self.sample_rate = sample_rate
  29. def clear_states(self, user):
  30. self.clear_buffers(user)
  31. self.clear_transcriptions(user)
  32. def clear_buffers(self, user):
  33. if user in self.client_buffers:
  34. del self.client_buffers[user]
  35. def clear_transcriptions(self, user):
  36. if user in self.client_transcription:
  37. del self.client_transcription[user]
  38. def disconnect(self, user):
  39. self.clear_states(user)
  40. print("Disconnecting user: %s" % str(user))
  41. def Recognize(self, request_iterator, context):
  42. for req in request_iterator:
  43. if req.user not in self.auth_user:
  44. result = {}
  45. result["success"] = False
  46. result["detail"] = "Not Authorized user: %s " % req.user
  47. result["text"] = ""
  48. yield Response(sentence=json.dumps(result), user=req.user, action="terminate", language=req.language)
  49. elif req.isEnd: #end grpc
  50. print("asr end")
  51. self.disconnect(req.user)
  52. result = {}
  53. result["success"] = True
  54. result["detail"] = "asr end"
  55. result["text"] = ""
  56. yield Response(sentence=json.dumps(result), user=req.user, action="terminate",language=req.language)
  57. elif req.speaking: #continue speaking
  58. if req.audio_data is not None and len(req.audio_data) > 0:
  59. if req.user in self.client_buffers:
  60. self.client_buffers[req.user] += req.audio_data #append audio
  61. else:
  62. self.client_buffers[req.user] = req.audio_data
  63. result = {}
  64. result["success"] = True
  65. result["detail"] = "speaking"
  66. result["text"] = ""
  67. yield Response(sentence=json.dumps(result), user=req.user, action="speaking", language=req.language)
  68. elif not req.speaking: #silence
  69. if req.user not in self.client_buffers:
  70. result = {}
  71. result["success"] = True
  72. result["detail"] = "waiting_for_more_voice"
  73. result["text"] = ""
  74. yield Response(sentence=json.dumps(result), user=req.user, action="waiting", language=req.language)
  75. else:
  76. begin_time = int(round(time.time() * 1000))
  77. tmp_data = self.client_buffers[req.user]
  78. self.clear_states(req.user)
  79. result = {}
  80. result["success"] = True
  81. result["detail"] = "decoding data: %d bytes" % len(tmp_data)
  82. result["text"] = ""
  83. yield Response(sentence=json.dumps(result), user=req.user, action="decoding", language=req.language)
  84. if len(tmp_data) < 9600: #min input_len for asr model , 300ms
  85. end_time = int(round(time.time() * 1000))
  86. delay_str = str(end_time - begin_time)
  87. result = {}
  88. result["success"] = True
  89. result["detail"] = "waiting_for_more_voice"
  90. result["server_delay_ms"] = delay_str
  91. result["text"] = ""
  92. print ("user: %s , delay(ms): %s, info: %s " % (req.user, delay_str, "waiting_for_more_voice"))
  93. yield Response(sentence=json.dumps(result), user=req.user, action="waiting", language=req.language)
  94. else:
  95. if self.backend == "pipeline":
  96. asr_result = self.inference_16k_pipeline(audio_in=tmp_data, audio_fs = self.sample_rate)
  97. if "text" in asr_result:
  98. asr_result = asr_result['text']
  99. else:
  100. asr_result = ""
  101. elif self.backend == "onnxruntime":
  102. from rapid_paraformer.utils.frontend import load_bytes
  103. array = load_bytes(tmp_data)
  104. asr_result = self.inference_16k_pipeline(array)[0]
  105. end_time = int(round(time.time() * 1000))
  106. delay_str = str(end_time - begin_time)
  107. print ("user: %s , delay(ms): %s, text: %s " % (req.user, delay_str, asr_result))
  108. result = {}
  109. result["success"] = True
  110. result["detail"] = "finish_sentence"
  111. result["server_delay_ms"] = delay_str
  112. result["text"] = asr_result
  113. yield Response(sentence=json.dumps(result), user=req.user, action="finish", language=req.language)
  114. else:
  115. result = {}
  116. result["success"] = False
  117. result["detail"] = "error, no condition matched! Unknown reason."
  118. result["text"] = ""
  119. self.disconnect(req.user)
  120. yield Response(sentence=json.dumps(result), user=req.user, action="terminate", language=req.language)