# client.py
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import multiprocessing
  15. from multiprocessing import Pool
  16. import argparse
  17. import os
  18. import tritonclient.grpc as grpcclient
  19. from utils import cal_cer
  20. from speech_client import *
  21. import numpy as np
  22. if __name__ == "__main__":
  23. parser = argparse.ArgumentParser()
  24. parser.add_argument(
  25. "-v",
  26. "--verbose",
  27. action="store_true",
  28. required=False,
  29. default=False,
  30. help="Enable verbose output",
  31. )
  32. parser.add_argument(
  33. "-u",
  34. "--url",
  35. type=str,
  36. required=False,
  37. default="localhost:10086",
  38. help="Inference server URL. Default is " "localhost:8001.",
  39. )
  40. parser.add_argument(
  41. "--model_name",
  42. required=False,
  43. default="attention_rescoring",
  44. choices=["attention_rescoring", "streaming_wenet", "infer_pipeline"],
  45. help="the model to send request to",
  46. )
  47. parser.add_argument(
  48. "--wavscp",
  49. type=str,
  50. required=False,
  51. default=None,
  52. help="audio_id \t wav_path",
  53. )
  54. parser.add_argument(
  55. "--trans",
  56. type=str,
  57. required=False,
  58. default=None,
  59. help="audio_id \t text",
  60. )
  61. parser.add_argument(
  62. "--data_dir",
  63. type=str,
  64. required=False,
  65. default=None,
  66. help="path prefix for wav_path in wavscp/audio_file",
  67. )
  68. parser.add_argument(
  69. "--audio_file",
  70. type=str,
  71. required=False,
  72. default=None,
  73. help="single wav file path",
  74. )
  75. # below arguments are for streaming
  76. # Please check onnx_config.yaml and train.yaml
  77. parser.add_argument("--streaming", action="store_true", required=False)
  78. parser.add_argument(
  79. "--sample_rate",
  80. type=int,
  81. required=False,
  82. default=16000,
  83. help="sample rate used in training",
  84. )
  85. parser.add_argument(
  86. "--frame_length_ms",
  87. type=int,
  88. required=False,
  89. default=25,
  90. help="frame length",
  91. )
  92. parser.add_argument(
  93. "--frame_shift_ms",
  94. type=int,
  95. required=False,
  96. default=10,
  97. help="frame shift length",
  98. )
  99. parser.add_argument(
  100. "--chunk_size",
  101. type=int,
  102. required=False,
  103. default=16,
  104. help="chunk size default is 16",
  105. )
  106. parser.add_argument(
  107. "--context",
  108. type=int,
  109. required=False,
  110. default=7,
  111. help="subsampling context",
  112. )
  113. parser.add_argument(
  114. "--subsampling",
  115. type=int,
  116. required=False,
  117. default=4,
  118. help="subsampling rate",
  119. )
  120. FLAGS = parser.parse_args()
  121. print(FLAGS)
  122. # load data
  123. filenames = []
  124. transcripts = []
  125. if FLAGS.audio_file is not None:
  126. path = FLAGS.audio_file
  127. if FLAGS.data_dir:
  128. path = os.path.join(FLAGS.data_dir, path)
  129. if os.path.exists(path):
  130. filenames = [path]
  131. elif FLAGS.wavscp is not None:
  132. audio_data = {}
  133. with open(FLAGS.wavscp, "r", encoding="utf-8") as f:
  134. for line in f:
  135. aid, path = line.strip().split("\t")
  136. if FLAGS.data_dir:
  137. path = os.path.join(FLAGS.data_dir, path)
  138. audio_data[aid] = {"path": path}
  139. with open(FLAGS.trans, "r", encoding="utf-8") as f:
  140. for line in f:
  141. aid, text = line.strip().split("\t")
  142. audio_data[aid]["text"] = text
  143. for key, value in audio_data.items():
  144. filenames.append(value["path"])
  145. transcripts.append(value["text"])
  146. num_workers = multiprocessing.cpu_count() // 2
  147. if FLAGS.streaming:
  148. speech_client_cls = StreamingSpeechClient
  149. else:
  150. speech_client_cls = OfflineSpeechClient
  151. def single_job(client_files):
  152. with grpcclient.InferenceServerClient(
  153. url=FLAGS.url, verbose=FLAGS.verbose
  154. ) as triton_client:
  155. protocol_client = grpcclient
  156. speech_client = speech_client_cls(
  157. triton_client, FLAGS.model_name, protocol_client, FLAGS
  158. )
  159. idx, audio_files = client_files
  160. predictions = []
  161. for li in audio_files:
  162. result = speech_client.recognize(li, idx)
  163. print("Recognized {}:{}".format(li, result[0]))
  164. predictions += result
  165. return predictions
  166. # start to do inference
  167. # Group requests in batches
  168. predictions = []
  169. tasks = []
  170. splits = np.array_split(filenames, num_workers)
  171. for idx, per_split in enumerate(splits):
  172. cur_files = per_split.tolist()
  173. tasks.append((idx, cur_files))
  174. with Pool(processes=num_workers) as pool:
  175. predictions = pool.map(single_job, tasks)
  176. predictions = [item for sublist in predictions for item in sublist]
  177. if transcripts:
  178. cer = cal_cer(predictions, transcripts)
  179. print("CER is: {}".format(cer))